mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
fix issues raised by Coverity scans (#7431)
This commit combines fixes for 37 potential code issues found in Coverity scans. the issues include but are not limited to potential access to uninitialized variables, dead and redundant code. We understand that reviewing such a commit can be difficult and will be happy to help with any questions or changes required. --------- Signed-off-by: Nir Sonnenschein <nsonnenschein@habana.ai> Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
This commit is contained in:
@ -67,7 +67,7 @@ def get_accelerator():
|
||||
f"XPU_Accelerator requires intel_extension_for_pytorch, which is not installed on this system.")
|
||||
elif accelerator_name == "xpu.external":
|
||||
try:
|
||||
import intel_extension_for_deepspeed # noqa: F401 # type: ignore
|
||||
from intel_extension_for_deepspeed import XPU_Accelerator # noqa: F401 # type: ignore
|
||||
except ImportError as e:
|
||||
raise ValueError(
|
||||
f"XPU_Accelerator external requires intel_extension_for_deepspeed, which is not installed on this system."
|
||||
@ -224,6 +224,12 @@ def get_accelerator():
|
||||
ds_accelerator = CPU_Accelerator()
|
||||
elif accelerator_name == "xpu.external":
|
||||
# XPU_Accelerator is already imported in detection stage
|
||||
try:
|
||||
from intel_extension_for_deepspeed import XPU_Accelerator # noqa: F811
|
||||
except ImportError as e:
|
||||
raise ValueError(
|
||||
f"XPU_Accelerator external requires intel_extension_for_deepspeed, which is not installed on this system."
|
||||
)
|
||||
ds_accelerator = XPU_Accelerator()
|
||||
elif accelerator_name == "xpu":
|
||||
from .xpu_accelerator import XPU_Accelerator
|
||||
@ -258,7 +264,7 @@ def get_accelerator():
|
||||
def set_accelerator(accel_obj):
|
||||
global ds_accelerator
|
||||
_validate_accelerator(accel_obj)
|
||||
if accel_logger is not None:
|
||||
if accel_logger is not None and accel_obj is not None:
|
||||
accel_logger.info(f"Setting ds_accelerator to {accel_obj._name} (model specified)")
|
||||
ds_accelerator = accel_obj
|
||||
|
||||
|
@ -81,7 +81,7 @@ class Autotuner:
|
||||
if not os.path.exists(self.results_dir):
|
||||
try:
|
||||
os.makedirs(self.results_dir, exist_ok=True)
|
||||
logger.info(f"Created autotuning results directory: {self.exps_dir}")
|
||||
logger.info(f"Created autotuning results directory: {self.results_dir}")
|
||||
except:
|
||||
logger.error(
|
||||
f"Failed to create {self.results_dir}, please check results_dir in the autotuning config file is accessible by all the nodes in the job."
|
||||
|
@ -144,7 +144,7 @@ DEFAULT_MIN_MEM_CONFIG = {
|
||||
"zero_optimization": {
|
||||
"stage": 3
|
||||
},
|
||||
"memory_break_down": False
|
||||
"memory_breakdown": False
|
||||
}
|
||||
|
||||
DEFAULT_TUNING_SPACE_ZERO_0 = {"zero_optimization": {"stage": 0}}
|
||||
|
@ -77,27 +77,12 @@ class CCLBackend(TorchBackend):
|
||||
return CCLHandler(self.ccl_comm_op)
|
||||
|
||||
def all_reduce(self, tensor, op=ReduceOp.SUM, group=None, async_op=False):
|
||||
use_caching = False
|
||||
if use_caching:
|
||||
match_id = f"{tensor.size()}-{op}"
|
||||
name = "all_reduce_caching"
|
||||
if name in self.available_coll:
|
||||
group = self.get_all_ranks_from_group(group)
|
||||
return self.ccl_comm_op.all_reduce_caching(tensor, op, match_id, group, async_op)
|
||||
else:
|
||||
return self.run_collective(name=name,
|
||||
tensor=tensor,
|
||||
op=op,
|
||||
match_id=match_id,
|
||||
group=group,
|
||||
async_op=async_op)
|
||||
name = "all_reduce"
|
||||
if name in self.available_coll:
|
||||
group = self.get_all_ranks_from_group(group)
|
||||
return self.ccl_comm_op.all_reduce(tensor, op, group, async_op)
|
||||
else:
|
||||
name = "all_reduce"
|
||||
if name in self.available_coll:
|
||||
group = self.get_all_ranks_from_group(group)
|
||||
return self.ccl_comm_op.all_reduce(tensor, op, group, async_op)
|
||||
else:
|
||||
return self.run_collective(name=name, tensor=tensor, op=op, group=group, async_op=async_op)
|
||||
return self.run_collective(name=name, tensor=tensor, op=op, group=group, async_op=async_op)
|
||||
|
||||
def inference_all_reduce(self, tensor, op=ReduceOp.SUM, group=None):
|
||||
name = "inference_all_reduce"
|
||||
|
@ -101,7 +101,7 @@ def reload_activation_bwd(graph: Graph, graph_id: int, graph_order: List[int], m
|
||||
with graph.inserting_after(reload_node):
|
||||
wait_node = graph.create_node('call_function',
|
||||
torch.ops.dc.wait_reload.default, (reload_node, graph_id, val_id), {},
|
||||
name=f"wait_copy_{node.name}_{val_id}")
|
||||
name=f"wait_copy_{reload_node.name}_{val_id}")
|
||||
|
||||
# replace all uses of node with wait_node
|
||||
users = {}
|
||||
|
@ -137,7 +137,7 @@ def module_replacement(model, module_name, compression_technique=None, mpu=None)
|
||||
else:
|
||||
new_module = None
|
||||
|
||||
if compression_technique is not None:
|
||||
if compression_technique is not None and new_module is not None:
|
||||
for k, v in compression_technique.items():
|
||||
if k == SPARSE_PRUNING:
|
||||
if v[SPARSE_PRUNING_ENABLED]:
|
||||
|
@ -37,12 +37,16 @@ def shard_qkv_param(param: torch.Tensor,
|
||||
if n_heads_kv is not None and n_heads_q is None:
|
||||
raise ValueError("n_heads_kv should not be passed without n_heads_q")
|
||||
|
||||
if param is None:
|
||||
raise ValueError("param should not be None")
|
||||
if n_heads_q is None:
|
||||
# Guaranteed to be in MHA
|
||||
if param.shape[0] // 3 % head_size != 0:
|
||||
raise ValueError("MHA param shape is not correct")
|
||||
n_heads_q = param.shape[0] // head_size // 3
|
||||
mha_sharding = True
|
||||
elif n_heads_kv is None:
|
||||
mha_sharding = True
|
||||
else:
|
||||
mha_sharding = n_heads_q == n_heads_kv
|
||||
|
||||
@ -73,9 +77,6 @@ def shard_qkv_param(param: torch.Tensor,
|
||||
else:
|
||||
even_kv_sharding = n_heads_kv >= num_shards
|
||||
|
||||
if param is None:
|
||||
return None
|
||||
|
||||
q_param = param[:head_size * n_heads_q]
|
||||
kv_param = param[head_size * n_heads_q:]
|
||||
|
||||
|
@ -122,9 +122,9 @@ class DSSequenceDescriptor(BaseSequenceDescriptor):
|
||||
|
||||
self._seen_tokens = 0
|
||||
self._in_flight_tokens = 0
|
||||
assert kv_cache_ids_shadow is not None # add check before use
|
||||
|
||||
self._num_allocation_groups = tuple(kv_cache_ids_shadow.shape[0]
|
||||
for kv_cache_ids_shadow in kv_cache_ids_shadow)
|
||||
self._num_allocation_groups = tuple(kv_cache_id.shape[0] for kv_cache_id in kv_cache_ids_shadow)
|
||||
self._blocks_per_allocation_group = tuple(
|
||||
torch.zeros(num_groups, dtype=torch.int32, device="cpu") for num_groups in self._num_allocation_groups)
|
||||
|
||||
|
@ -73,6 +73,8 @@ class MegatronLayerPolicy(TransformerPolicy):
|
||||
attention = self.client_module.attention
|
||||
else:
|
||||
attention = self.client_module.self_attention
|
||||
else:
|
||||
return None
|
||||
|
||||
return attention.query_key_value.weight, \
|
||||
attention.query_key_value.bias, \
|
||||
|
@ -93,8 +93,10 @@ def generic_injection(module, dtype=None, enable_cuda_graph=True):
|
||||
return child
|
||||
if len(policy_attn) == 5:
|
||||
qkvw, attn_ow, attn_ob, hidden_size, heads = policy_attn
|
||||
qw, kw, vw = torch.empty(0), torch.empty(0), torch.empty(0)
|
||||
else:
|
||||
qw, kw, vw, attn_ow, attn_ob, hidden_size, heads = policy_attn
|
||||
qkvw = torch.empty(0)
|
||||
|
||||
config = transformer_inference.DeepSpeedInferenceConfig(
|
||||
hidden_size=hidden_size,
|
||||
@ -113,11 +115,15 @@ def generic_injection(module, dtype=None, enable_cuda_graph=True):
|
||||
return data
|
||||
|
||||
if len(policy_attn) == 5:
|
||||
assert qkvw is not None and qkvw.data is not None, "qkvw can't be None"
|
||||
attn_module.attn_qkvw.data = transpose(qkvw.data)
|
||||
else:
|
||||
attn_module.attn_qkvw = None
|
||||
assert qw is not None and qw.data is not None, "qw can't be None"
|
||||
attn_module.attn_qw.data = transpose(qw.data)
|
||||
assert kw is not None and kw.data is not None, "kw can't be None"
|
||||
attn_module.attn_kw.data = transpose(kw.data)
|
||||
assert vw is not None and vw.data is not None, "vw can't be None"
|
||||
attn_module.attn_vw.data = transpose(vw.data)
|
||||
|
||||
attn_module.attn_qkvb = None
|
||||
@ -316,21 +322,15 @@ def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, m
|
||||
return _autotp._replace_module(module)
|
||||
|
||||
def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None):
|
||||
training = False # todo: refactor this part to go in the config
|
||||
if training:
|
||||
# copy relevant state from child -> new module
|
||||
new_module = replace_with_policy(child, _policy, config.triangular_masking)
|
||||
|
||||
# copy relevant state from child -> new module
|
||||
if not is_autotp_training_mode() and config.replace_with_kernel_inject:
|
||||
new_module = replace_with_policy(child,
|
||||
_policy,
|
||||
config.triangular_masking,
|
||||
inference=True,
|
||||
layer_id=layer_id)
|
||||
else:
|
||||
# copy relevant state from child -> new module
|
||||
if not is_autotp_training_mode() and config.replace_with_kernel_inject:
|
||||
new_module = replace_with_policy(child,
|
||||
_policy,
|
||||
config.triangular_masking,
|
||||
inference=True,
|
||||
layer_id=layer_id)
|
||||
else:
|
||||
new_module = replace_wo_policy(child, _policy, prefix=prefix, state_dict=state_dict)
|
||||
new_module = replace_wo_policy(child, _policy, prefix=prefix, state_dict=state_dict)
|
||||
|
||||
return new_module
|
||||
|
||||
|
@ -400,7 +400,7 @@ def topkgating(
|
||||
me = torch.mean(gates, dim=0)
|
||||
ce = torch.mean(mask.float(), dim=0)
|
||||
l_aux = torch.mean(me * ce) * num_experts * num_experts / k
|
||||
|
||||
locations = None
|
||||
if drop_tokens:
|
||||
# Calculate configured capacity and remove locations outside capacity from mask
|
||||
capacity = _capacity(gates, torch.tensor(capacity_factor * k), torch.tensor(min_capacity))
|
||||
@ -437,6 +437,8 @@ def topkgating(
|
||||
denom_s = torch.clamp(gates_s, min=torch.finfo(gates_masked.dtype).eps)
|
||||
gates_masked = gates_masked / denom_s
|
||||
|
||||
if locations is None:
|
||||
raise ValueError(f"Locations is not set: {locations}")
|
||||
# dispatch_mask
|
||||
locations_sc = _one_hot_to_float((locations * mask), capacity)
|
||||
|
||||
|
@ -128,18 +128,18 @@ def _kernel(A, B, C, stride_za, stride_ha, stride_ma, stride_ka, stride_zb, stri
|
||||
inc_b = TK * stride_kb
|
||||
else:
|
||||
pinc += 2
|
||||
if meta['DSD']:
|
||||
inc_b = tl.load(pinc)
|
||||
inc_a = tl.load(pinc + 1)
|
||||
inc_b = tl.multiple_of(inc_b, 8)
|
||||
inc_a = tl.multiple_of(inc_a, 8)
|
||||
inc_b = inc_b * stride_kb
|
||||
if meta['DDS']:
|
||||
inc_a = tl.load(pinc)
|
||||
inc_b = tl.load(pinc + 1)
|
||||
inc_a = tl.multiple_of(inc_a, 8)
|
||||
inc_b = tl.multiple_of(inc_b, 8)
|
||||
inc_a = inc_a * stride_ka
|
||||
if meta['DSD']:
|
||||
inc_b = tl.load(pinc)
|
||||
inc_a = tl.load(pinc + 1)
|
||||
inc_b = tl.multiple_of(inc_b, 8)
|
||||
inc_a = tl.multiple_of(inc_a, 8)
|
||||
inc_b = inc_b * stride_kb
|
||||
if meta['DDS']:
|
||||
inc_a = tl.load(pinc)
|
||||
inc_b = tl.load(pinc + 1)
|
||||
inc_a = tl.multiple_of(inc_a, 8)
|
||||
inc_b = tl.multiple_of(inc_b, 8)
|
||||
inc_a = inc_a * stride_ka
|
||||
pa += inc_a
|
||||
pb += inc_b
|
||||
# pre-fetch
|
||||
|
@ -57,7 +57,7 @@ class DeepSpeedDiffusersTransformerBlock(nn.Module):
|
||||
self.attn_2.do_out_bias = False
|
||||
self.attn_2_bias = self.attn_2.attn_ob
|
||||
else:
|
||||
self.attn_2_bias = nn.Paramaeter(torch.zeros_like(self.norm3_g), requires_grad=False)
|
||||
self.attn_2_bias = nn.Parameter(torch.zeros_like(self.norm3_g), requires_grad=False)
|
||||
|
||||
self.gated_activation = GatedActivationOp()
|
||||
self.layer_norm = LayerNormOp()
|
||||
|
@ -335,27 +335,29 @@ class DeepSpeedTransformerLayer(nn.Module):
|
||||
self.norm_b = nn.Parameter(torch.Tensor(self.config.hidden_size))
|
||||
self.init_transformer_weights(self.config.adjust_init_range)
|
||||
else:
|
||||
# For testing only.
|
||||
q = initial_weights[0].data
|
||||
k = initial_weights[1].data
|
||||
v = initial_weights[2].data
|
||||
if initial_weights is not None:
|
||||
# For testing only.
|
||||
q = initial_weights[0].data
|
||||
k = initial_weights[1].data
|
||||
v = initial_weights[2].data
|
||||
|
||||
self.attn_qkvw = nn.Parameter(torch.cat((q, k, v)))
|
||||
#self.attn_qkvw[i * self.config.hidden_size:(i + 1) * self.config.hidden_size] = \
|
||||
# initial_weights[i].clone()
|
||||
#torch.empty_like(initial_weights[i]).data.copy_(initial_weights[i].data)
|
||||
self.attn_qkvb = nn.Parameter(torch.Tensor(self.config.hidden_size * 3))
|
||||
self.attn_qkvb.data.zero_()
|
||||
self.attn_ow = initial_weights[3]
|
||||
self.attn_ob = initial_biases[3]
|
||||
self.attn_nw = initial_weights[4]
|
||||
self.attn_nb = initial_biases[4]
|
||||
self.inter_w = initial_weights[5]
|
||||
self.inter_b = initial_biases[5]
|
||||
self.output_w = initial_weights[6]
|
||||
self.output_b = initial_biases[6]
|
||||
self.norm_w = initial_weights[7]
|
||||
self.norm_b = initial_biases[7]
|
||||
self.attn_qkvw = nn.Parameter(torch.cat((q, k, v)))
|
||||
#self.attn_qkvw[i * self.config.hidden_size:(i + 1) * self.config.hidden_size] = \
|
||||
# initial_weights[i].clone()
|
||||
#torch.empty_like(initial_weights[i]).data.copy_(initial_weights[i].data)
|
||||
self.attn_qkvb = nn.Parameter(torch.Tensor(self.config.hidden_size * 3))
|
||||
self.attn_qkvb.data.zero_()
|
||||
self.attn_ow = initial_weights[3]
|
||||
self.attn_nw = initial_weights[4]
|
||||
self.inter_w = initial_weights[5]
|
||||
self.output_w = initial_weights[6]
|
||||
self.norm_w = initial_weights[7]
|
||||
if initial_biases is not None:
|
||||
self.attn_ob = initial_biases[3]
|
||||
self.attn_nb = initial_biases[4]
|
||||
self.inter_b = initial_biases[5]
|
||||
self.output_b = initial_biases[6]
|
||||
self.norm_b = initial_biases[7]
|
||||
|
||||
# Load cuda modules if needed
|
||||
global transformer_cuda_module, stochastic_transformer_cuda_module
|
||||
|
@ -715,7 +715,7 @@ class DistributedDataAnalyzer(object):
|
||||
buffer = torch.cat(tensor_list, dim=0).to(self.device)
|
||||
write_buffer_to_file(buffer, 0, builder)
|
||||
elif self.worker_id == 0 and src > 0: # rank 0 receives other rank's data and writes it
|
||||
buffer = torch.empty(sizes[src].item(), dtype=buffer.dtype, device=buffer.device)
|
||||
buffer = torch.empty(sizes[src].item(), dtype=numpy_dtype, device=self.device)
|
||||
err = dist.recv(buffer, src=src, group=self.comm_group, tag=src)
|
||||
assert err == src and len(buffer) > 0, "recv failed"
|
||||
write_buffer_to_file(buffer, src, builder)
|
||||
|
@ -407,7 +407,6 @@ class DeepSpeedEngine(Module):
|
||||
for _, module in self.module.named_modules():
|
||||
if isinstance(module, LoRAOptimizedLinear):
|
||||
self.optimized_linear_lora_enabled = True
|
||||
offload_ratio = None
|
||||
if offload_ratio is not None:
|
||||
assert offload_ratio == module.lora_config.offload_ratio, \
|
||||
"all lora_config offload ratios should be the same across the model"
|
||||
@ -1262,10 +1261,6 @@ class DeepSpeedEngine(Module):
|
||||
@staticmethod
|
||||
def __check_params(model: Module, dtype: torch.dtype) -> None:
|
||||
return
|
||||
if not all(param.dtype == dtype for param in model.parameters()) and dist.get_rank() == 0:
|
||||
raise ValueError(f"{dtype} is enabled but the following parameters have dtype that is "
|
||||
f"not {dtype}: "
|
||||
f"{[(n, p.dtype) for n, p in model.named_parameters() if p.dtype != dtype]}")
|
||||
|
||||
def _set_client_model(self, model):
|
||||
# register client model in _modules so that nn.module methods work correctly
|
||||
|
@ -13,7 +13,6 @@ from deepspeed import comm as dist
|
||||
|
||||
from deepspeed.utils import logger
|
||||
from deepspeed.utils.timer import ThroughputTimer
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
from deepspeed.runtime.bf16_optimizer import BF16_Optimizer
|
||||
|
||||
from ..engine import DeepSpeedEngine, MEMORY_OPT_ALLREDUCE_SIZE
|
||||
@ -712,7 +711,6 @@ class PipelineEngine(DeepSpeedEngine):
|
||||
|
||||
def _exec_forward_pass(self, buffer_id):
|
||||
self.tput_timer.start()
|
||||
self.mem_status('BEFORE FWD', reset_max=True)
|
||||
|
||||
if isinstance(self.pipe_buffers['inputs'][buffer_id], tuple):
|
||||
inputs = tuple(t.clone() for t in self.pipe_buffers['inputs'][buffer_id])
|
||||
@ -808,13 +806,10 @@ class PipelineEngine(DeepSpeedEngine):
|
||||
assert self.optimizer is not None, "must provide optimizer during " \
|
||||
"init in order to use backward"
|
||||
|
||||
self.mem_status('BEFORE BWD', reset_max=True)
|
||||
|
||||
# The last stage just runs backward on the loss using DeepSpeed's typical
|
||||
# mechanisms.
|
||||
if self.is_last_stage():
|
||||
super().backward(self.loss)
|
||||
self.mem_status('AFTER BWD')
|
||||
return
|
||||
|
||||
outputs = self.pipe_buffers['outputs'][buffer_id]
|
||||
@ -881,8 +876,6 @@ class PipelineEngine(DeepSpeedEngine):
|
||||
self.timers(BACKWARD_MICRO_TIMER).stop()
|
||||
self.timers(BACKWARD_GLOBAL_TIMER).stop()
|
||||
|
||||
self.mem_status('AFTER BWD')
|
||||
|
||||
def _exec_load_micro_batch(self, buffer_id):
|
||||
if self.wall_clock_breakdown():
|
||||
self.timers(BATCH_INPUT_TIMER).start()
|
||||
@ -1221,14 +1214,11 @@ class PipelineEngine(DeepSpeedEngine):
|
||||
if self.wall_clock_breakdown():
|
||||
self.timers(STEP_MICRO_TIMER).start()
|
||||
self.timers(STEP_GLOBAL_TIMER).start()
|
||||
self.mem_status('BEFORE STEP', reset_max=True)
|
||||
|
||||
self._force_grad_boundary = True
|
||||
self._take_model_step(lr_kwargs)
|
||||
self._force_grad_boundary = False
|
||||
|
||||
self.mem_status('AFTER STEP')
|
||||
|
||||
if self.global_rank == 0 and self.monitor.enabled:
|
||||
self.summary_events = [(f'Train/Samples/lr', self.get_lr()[0], self.global_samples)]
|
||||
if self.fp16_enabled() and hasattr(self.optimizer, 'cur_scale'):
|
||||
@ -1307,53 +1297,6 @@ class PipelineEngine(DeepSpeedEngine):
|
||||
"""Disabled for pipeline parallel training. See ``train_batch()``. """
|
||||
raise PipelineError("Only train_batch() is accessible in pipeline mode.")
|
||||
|
||||
def mem_status(self, msg, print_rank=-1, reset_max=False):
|
||||
return
|
||||
global mem_alloced, mem_cached
|
||||
if not self.global_steps == 0 or not self.global_steps == 9:
|
||||
#return
|
||||
pass
|
||||
if self.mpu.get_data_parallel_rank() != 0:
|
||||
return
|
||||
|
||||
if self.global_rank != 0:
|
||||
return
|
||||
|
||||
rank = self.global_rank
|
||||
if print_rank != -1 and rank != print_rank:
|
||||
return
|
||||
|
||||
get_accelerator().synchronize()
|
||||
|
||||
if reset_max:
|
||||
get_accelerator().reset_max_memory_cached()
|
||||
get_accelerator().reset_max_memory_allocated()
|
||||
|
||||
new_alloced = get_accelerator().memory_allocated()
|
||||
new_cached = get_accelerator().memory_cached()
|
||||
|
||||
delta_alloced = new_alloced - mem_alloced
|
||||
delta_cached = new_cached - mem_cached
|
||||
|
||||
mem_cached = new_cached
|
||||
mem_alloced = new_alloced
|
||||
|
||||
max_alloced = get_accelerator().max_memory_allocated()
|
||||
max_cached = get_accelerator().max_memory_cached()
|
||||
|
||||
# convert to GB for printing
|
||||
new_alloced /= 1024**3
|
||||
new_cached /= 1024**3
|
||||
delta_alloced /= 1024**3
|
||||
delta_cached /= 1024**3
|
||||
max_alloced /= 1024**3
|
||||
max_cached /= 1024**3
|
||||
|
||||
print(
|
||||
f'RANK={rank} STAGE={self.stage_id} STEP={self.global_steps} MEMSTATS', msg,
|
||||
f'current alloc={new_alloced:0.4f}GB (delta={delta_alloced:0.4f}GB max={max_alloced:0.4f}GB) '
|
||||
f'current cache={new_cached:0.4f}GB (delta={delta_cached:0.4f}GB max={max_cached:0.4f}GB)')
|
||||
|
||||
def module_state_dict(self, exclude_frozen_parameters=False):
|
||||
"""Override hack to save a pipe model and return the directory path of the save.
|
||||
|
||||
|
@ -15,10 +15,10 @@ class SparseTensor(object):
|
||||
|
||||
def __init__(self, dense_tensor=None):
|
||||
self.orig_dense_tensor = dense_tensor
|
||||
self.dtype = self.orig_dense_tensor.dtype
|
||||
self.is_sparse = dense_tensor.is_sparse
|
||||
if dense_tensor is not None:
|
||||
if dense_tensor.is_sparse:
|
||||
self.is_sparse = dense_tensor.is_sparse
|
||||
self.dtype = self.orig_dense_tensor.dtype
|
||||
if self.is_sparse:
|
||||
dense_tensor = dense_tensor.coalesce()
|
||||
self.indices = dense_tensor.indices().flatten()
|
||||
self.values = dense_tensor.values()
|
||||
|
@ -1656,16 +1656,11 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
|
||||
|
||||
tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size))
|
||||
|
||||
if rank is None:
|
||||
# "All Reducing"
|
||||
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
|
||||
else:
|
||||
global_rank = dist.get_global_rank(self.dp_process_group, rank)
|
||||
dist.reduce(tensor_to_allreduce, global_rank, group=self.dp_process_group)
|
||||
# "All Reducing"
|
||||
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
|
||||
|
||||
if communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce:
|
||||
if rank is None or rank == dist.get_rank(group=self.dp_process_group):
|
||||
tensor.copy_(tensor_to_allreduce)
|
||||
tensor.copy_(tensor_to_allreduce)
|
||||
|
||||
return tensor
|
||||
|
||||
|
@ -1346,10 +1346,9 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
|
||||
dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(0, dest_offset, num_elements)
|
||||
|
||||
grad_accum = self.get_param_gradient_attribute(param)
|
||||
if grad_accum is None:
|
||||
src_tensor = grad_accum.view(-1).narrow(0, source_offset, num_elements)
|
||||
else:
|
||||
src_tensor = grad_accum.view(-1).narrow(0, source_offset, num_elements)
|
||||
assert grad_accum is not None
|
||||
|
||||
src_tensor = grad_accum.view(-1).narrow(0, source_offset, num_elements)
|
||||
if not self.fp16_master_weights_and_gradients:
|
||||
src_tensor = src_tensor.float()
|
||||
|
||||
|
@ -124,7 +124,7 @@ class FPDT_InputConstruct(torch.nn.Module):
|
||||
load_balanced_tokens = self.tokens[:, indices]
|
||||
load_balanced_labels = self.labels[:, indices] if self.labels is not None else self.labels
|
||||
|
||||
load_balanced_attention_mask = self.attention_mask if self.attention_mask is not None else self.attention_mask
|
||||
load_balanced_attention_mask = self.attention_mask
|
||||
load_balanced_position_ids = self.position_ids[:,
|
||||
indices] if self.position_ids is not None else self.position_ids
|
||||
|
||||
|
@ -404,7 +404,7 @@ def _create_expert_data_and_model_parallel(expert_parallel_size_, mpu, use_data_
|
||||
|
||||
world_size = dist.get_world_size()
|
||||
rank = dist.get_rank()
|
||||
dp_world_size = mpu.get_data_parallel_world_size()
|
||||
dp_world_size = _get_data_parallel_world_size()
|
||||
pp_world_size = 1 if mpu is None else bwc_pipeline_parallel_world_size(mpu)
|
||||
|
||||
_ensure_divisibility(world_size, tensor_parallel_size_)
|
||||
@ -569,31 +569,37 @@ def _get_data_parallel_group_ranks():
|
||||
|
||||
|
||||
def _get_broadcast_src_rank():
|
||||
assert dist.is_initialized(), 'dist is not initialized'
|
||||
return dist.get_global_rank(_get_sequence_data_parallel_group(), 0)
|
||||
|
||||
|
||||
def _get_expert_broadcast_src_rank(group_name):
|
||||
assert dist.is_initialized(), 'dist is not initialized'
|
||||
return dist.get_global_rank(_get_expert_data_parallel_group(group_name), 0)
|
||||
|
||||
|
||||
def _get_expert_parallel_world_size(group_name):
|
||||
"""Return world size for the expert parallel group."""
|
||||
assert dist.is_initialized(), 'dist is not initialized'
|
||||
return dist.get_world_size(group=_get_expert_parallel_group(group_name))
|
||||
|
||||
|
||||
def _get_expert_data_parallel_world_size(group_name):
|
||||
"""Return world size for the expert data parallel group."""
|
||||
assert dist.is_initialized(), 'dist is not initialized'
|
||||
return dist.get_world_size(group=_get_expert_data_parallel_group(group_name))
|
||||
|
||||
|
||||
def _get_expert_parallel_rank(group_name):
|
||||
"""Return my rank for the expert parallel group."""
|
||||
assert dist.is_initialized(), 'dist is not initialized'
|
||||
return dist.get_rank(group=_get_expert_parallel_group(group_name))
|
||||
|
||||
|
||||
def _get_expert_parallel_src_rank(group_name):
|
||||
"""Calculate the global rank corresponding to a local rank zero
|
||||
in the expert parallel group."""
|
||||
assert dist.is_initialized(), 'dist is not initialized'
|
||||
global_rank = dist.get_rank()
|
||||
local_world_size = _get_expert_parallel_world_size(group_name)
|
||||
return (global_rank // local_world_size) * local_world_size
|
||||
@ -601,11 +607,13 @@ def _get_expert_parallel_src_rank(group_name):
|
||||
|
||||
def _get_expert_data_parallel_rank(group_name):
|
||||
"""Return my rank for the expert data parallel group."""
|
||||
assert dist.is_initialized(), 'dist is not initialized'
|
||||
return dist.get_rank(group=_get_expert_data_parallel_group(group_name))
|
||||
|
||||
|
||||
def _get_data_parallel_world_size():
|
||||
"""Return world size for the data parallel group."""
|
||||
assert dist.is_initialized(), 'dist is not initialized'
|
||||
if mesh_device is not None:
|
||||
return dist.get_world_size(mesh_device.get_group(mesh_dim="data_parallel"))
|
||||
global mpu
|
||||
@ -627,11 +635,13 @@ def _get_model_parallel_world_size():
|
||||
|
||||
def _get_data_parallel_rank():
|
||||
"""Return my rank for the data parallel group."""
|
||||
assert dist.is_initialized(), 'dist is not initialized'
|
||||
return dist.get_rank(group=_get_data_parallel_group())
|
||||
|
||||
|
||||
def _get_sequence_parallel_world_size():
|
||||
"""Return world size for the sequence parallel group."""
|
||||
"""Return world size for the model parallel group."""
|
||||
assert dist.is_initialized(), 'dist is not initialized'
|
||||
global mpu
|
||||
if mesh_device is not None:
|
||||
return dist.get_world_size(mesh_device.get_group(mesh_dim="sequence_parallel"))
|
||||
|
@ -431,7 +431,6 @@ class OpBuilder(ABC):
|
||||
print(f"{WARNING} {self.name} cuda is missing or is incompatible with installed torch, "
|
||||
"only cpu ops can be compiled!")
|
||||
return '-D__DISABLE_CUDA__'
|
||||
return '-D__DISABLE_CUDA__'
|
||||
|
||||
def _backup_cpuinfo(self):
|
||||
# Construct cpu_info dict from lscpu that is similar to what py-cpuinfo provides
|
||||
|
@ -39,7 +39,6 @@ class CCLCommBuilder(CPUOpBuilder):
|
||||
raise ValueError(
|
||||
"Didn't find CCL_ROOT, install oneCCL from https://github.com/oneapi-src/oneCCL and source its environment variable"
|
||||
)
|
||||
return []
|
||||
else:
|
||||
return ['-lccl', f'-L{ccl_root_path}/lib']
|
||||
|
||||
|
Reference in New Issue
Block a user