fix issues raised by Coverity scans (#7431)

This commit combines fixes for 37 potential code issues found in
Coverity scans.
The issues include, but are not limited to, potential access to uninitialized variables and dead or redundant code.
We understand that reviewing such a combined commit can be difficult and are happy to help with any questions or requested changes.
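For reviewers less familiar with these defect classes, the short sketch below (illustrative only; the function and names are hypothetical, not code from this commit) shows the guard pattern used in several of the hunks that follow: a conditionally assigned value is initialized up front and checked before use, rather than being read on a path where it may never have been set. The remaining changes mostly delete branches that can never execute (for example, code behind a constant flag) or collapse duplicated assignments without changing behavior.

    from typing import Optional, Sequence

    def select_first_at_least(candidates: Sequence[int], minimum: int) -> int:
        # Initialize up front so every code path defines the variable.
        chosen: Optional[int] = None
        for value in candidates:
            if value >= minimum:
                chosen = value
                break
        # Guard before use; without the initialization and this check,
        # static analysis flags a possible read of an unset value.
        if chosen is None:
            raise ValueError(f"no candidate is >= {minimum}")
        return chosen

    print(select_first_at_least([2, 8, 16], minimum=4))  # prints 8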

---------

Signed-off-by: Nir Sonnenschein <nsonnenschein@habana.ai>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Authored by Nir Sonnenschein on 2025-08-02 19:16:10 +03:00; committed by GitHub.
parent 0e51e09396
commit 1a8ad24f0d
24 changed files with 100 additions and 162 deletions

View File

@@ -67,7 +67,7 @@ def get_accelerator():
f"XPU_Accelerator requires intel_extension_for_pytorch, which is not installed on this system.")
elif accelerator_name == "xpu.external":
try:
import intel_extension_for_deepspeed # noqa: F401 # type: ignore
from intel_extension_for_deepspeed import XPU_Accelerator # noqa: F401 # type: ignore
except ImportError as e:
raise ValueError(
f"XPU_Accelerator external requires intel_extension_for_deepspeed, which is not installed on this system."
@@ -224,6 +224,12 @@ def get_accelerator():
ds_accelerator = CPU_Accelerator()
elif accelerator_name == "xpu.external":
# XPU_Accelerator is already imported in detection stage
try:
from intel_extension_for_deepspeed import XPU_Accelerator # noqa: F811
except ImportError as e:
raise ValueError(
f"XPU_Accelerator external requires intel_extension_for_deepspeed, which is not installed on this system."
)
ds_accelerator = XPU_Accelerator()
elif accelerator_name == "xpu":
from .xpu_accelerator import XPU_Accelerator
@@ -258,7 +264,7 @@ def get_accelerator():
def set_accelerator(accel_obj):
global ds_accelerator
_validate_accelerator(accel_obj)
if accel_logger is not None:
if accel_logger is not None and accel_obj is not None:
accel_logger.info(f"Setting ds_accelerator to {accel_obj._name} (model specified)")
ds_accelerator = accel_obj

View File

@@ -81,7 +81,7 @@ class Autotuner:
if not os.path.exists(self.results_dir):
try:
os.makedirs(self.results_dir, exist_ok=True)
logger.info(f"Created autotuning results directory: {self.exps_dir}")
logger.info(f"Created autotuning results directory: {self.results_dir}")
except:
logger.error(
f"Failed to create {self.results_dir}, please check results_dir in the autotuning config file is accessible by all the nodes in the job."

View File

@@ -144,7 +144,7 @@ DEFAULT_MIN_MEM_CONFIG = {
"zero_optimization": {
"stage": 3
},
"memory_break_down": False
"memory_breakdown": False
}
DEFAULT_TUNING_SPACE_ZERO_0 = {"zero_optimization": {"stage": 0}}

View File

@@ -77,27 +77,12 @@ class CCLBackend(TorchBackend):
return CCLHandler(self.ccl_comm_op)
def all_reduce(self, tensor, op=ReduceOp.SUM, group=None, async_op=False):
use_caching = False
if use_caching:
match_id = f"{tensor.size()}-{op}"
name = "all_reduce_caching"
if name in self.available_coll:
group = self.get_all_ranks_from_group(group)
return self.ccl_comm_op.all_reduce_caching(tensor, op, match_id, group, async_op)
else:
return self.run_collective(name=name,
tensor=tensor,
op=op,
match_id=match_id,
group=group,
async_op=async_op)
name = "all_reduce"
if name in self.available_coll:
group = self.get_all_ranks_from_group(group)
return self.ccl_comm_op.all_reduce(tensor, op, group, async_op)
else:
name = "all_reduce"
if name in self.available_coll:
group = self.get_all_ranks_from_group(group)
return self.ccl_comm_op.all_reduce(tensor, op, group, async_op)
else:
return self.run_collective(name=name, tensor=tensor, op=op, group=group, async_op=async_op)
return self.run_collective(name=name, tensor=tensor, op=op, group=group, async_op=async_op)
def inference_all_reduce(self, tensor, op=ReduceOp.SUM, group=None):
name = "inference_all_reduce"

View File

@@ -101,7 +101,7 @@ def reload_activation_bwd(graph: Graph, graph_id: int, graph_order: List[int], m
with graph.inserting_after(reload_node):
wait_node = graph.create_node('call_function',
torch.ops.dc.wait_reload.default, (reload_node, graph_id, val_id), {},
name=f"wait_copy_{node.name}_{val_id}")
name=f"wait_copy_{reload_node.name}_{val_id}")
# replace all uses of node with wait_node
users = {}

View File

@@ -137,7 +137,7 @@ def module_replacement(model, module_name, compression_technique=None, mpu=None)
else:
new_module = None
if compression_technique is not None:
if compression_technique is not None and new_module is not None:
for k, v in compression_technique.items():
if k == SPARSE_PRUNING:
if v[SPARSE_PRUNING_ENABLED]:

View File

@@ -37,12 +37,16 @@ def shard_qkv_param(param: torch.Tensor,
if n_heads_kv is not None and n_heads_q is None:
raise ValueError("n_heads_kv should not be passed without n_heads_q")
if param is None:
raise ValueError("param should not be None")
if n_heads_q is None:
# Guaranteed to be in MHA
if param.shape[0] // 3 % head_size != 0:
raise ValueError("MHA param shape is not correct")
n_heads_q = param.shape[0] // head_size // 3
mha_sharding = True
elif n_heads_kv is None:
mha_sharding = True
else:
mha_sharding = n_heads_q == n_heads_kv
@@ -73,9 +77,6 @@ def shard_qkv_param(param: torch.Tensor,
else:
even_kv_sharding = n_heads_kv >= num_shards
if param is None:
return None
q_param = param[:head_size * n_heads_q]
kv_param = param[head_size * n_heads_q:]

View File

@@ -122,9 +122,9 @@ class DSSequenceDescriptor(BaseSequenceDescriptor):
self._seen_tokens = 0
self._in_flight_tokens = 0
assert kv_cache_ids_shadow is not None # add check before use
self._num_allocation_groups = tuple(kv_cache_ids_shadow.shape[0]
for kv_cache_ids_shadow in kv_cache_ids_shadow)
self._num_allocation_groups = tuple(kv_cache_id.shape[0] for kv_cache_id in kv_cache_ids_shadow)
self._blocks_per_allocation_group = tuple(
torch.zeros(num_groups, dtype=torch.int32, device="cpu") for num_groups in self._num_allocation_groups)

View File

@@ -73,6 +73,8 @@ class MegatronLayerPolicy(TransformerPolicy):
attention = self.client_module.attention
else:
attention = self.client_module.self_attention
else:
return None
return attention.query_key_value.weight, \
attention.query_key_value.bias, \

View File

@@ -93,8 +93,10 @@ def generic_injection(module, dtype=None, enable_cuda_graph=True):
return child
if len(policy_attn) == 5:
qkvw, attn_ow, attn_ob, hidden_size, heads = policy_attn
qw, kw, vw = torch.empty(0), torch.empty(0), torch.empty(0)
else:
qw, kw, vw, attn_ow, attn_ob, hidden_size, heads = policy_attn
qkvw = torch.empty(0)
config = transformer_inference.DeepSpeedInferenceConfig(
hidden_size=hidden_size,
@@ -113,11 +115,15 @@ def generic_injection(module, dtype=None, enable_cuda_graph=True):
return data
if len(policy_attn) == 5:
assert qkvw is not None and qkvw.data is not None, "qkvw can't be None"
attn_module.attn_qkvw.data = transpose(qkvw.data)
else:
attn_module.attn_qkvw = None
assert qw is not None and qw.data is not None, "qw can't be None"
attn_module.attn_qw.data = transpose(qw.data)
assert kw is not None and kw.data is not None, "kw can't be None"
attn_module.attn_kw.data = transpose(kw.data)
assert vw is not None and vw.data is not None, "vw can't be None"
attn_module.attn_vw.data = transpose(vw.data)
attn_module.attn_qkvb = None
@@ -316,21 +322,15 @@ def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, m
return _autotp._replace_module(module)
def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None):
training = False # todo: refactor this part to go in the config
if training:
# copy relevant state from child -> new module
new_module = replace_with_policy(child, _policy, config.triangular_masking)
# copy relevant state from child -> new module
if not is_autotp_training_mode() and config.replace_with_kernel_inject:
new_module = replace_with_policy(child,
_policy,
config.triangular_masking,
inference=True,
layer_id=layer_id)
else:
# copy relevant state from child -> new module
if not is_autotp_training_mode() and config.replace_with_kernel_inject:
new_module = replace_with_policy(child,
_policy,
config.triangular_masking,
inference=True,
layer_id=layer_id)
else:
new_module = replace_wo_policy(child, _policy, prefix=prefix, state_dict=state_dict)
new_module = replace_wo_policy(child, _policy, prefix=prefix, state_dict=state_dict)
return new_module

View File

@@ -400,7 +400,7 @@ def topkgating(
me = torch.mean(gates, dim=0)
ce = torch.mean(mask.float(), dim=0)
l_aux = torch.mean(me * ce) * num_experts * num_experts / k
locations = None
if drop_tokens:
# Calculate configured capacity and remove locations outside capacity from mask
capacity = _capacity(gates, torch.tensor(capacity_factor * k), torch.tensor(min_capacity))
@@ -437,6 +437,8 @@ def topkgating(
denom_s = torch.clamp(gates_s, min=torch.finfo(gates_masked.dtype).eps)
gates_masked = gates_masked / denom_s
if locations is None:
raise ValueError(f"Locations is not set: {locations}")
# dispatch_mask
locations_sc = _one_hot_to_float((locations * mask), capacity)

View File

@@ -128,18 +128,18 @@ def _kernel(A, B, C, stride_za, stride_ha, stride_ma, stride_ka, stride_zb, stri
inc_b = TK * stride_kb
else:
pinc += 2
if meta['DSD']:
inc_b = tl.load(pinc)
inc_a = tl.load(pinc + 1)
inc_b = tl.multiple_of(inc_b, 8)
inc_a = tl.multiple_of(inc_a, 8)
inc_b = inc_b * stride_kb
if meta['DDS']:
inc_a = tl.load(pinc)
inc_b = tl.load(pinc + 1)
inc_a = tl.multiple_of(inc_a, 8)
inc_b = tl.multiple_of(inc_b, 8)
inc_a = inc_a * stride_ka
if meta['DSD']:
inc_b = tl.load(pinc)
inc_a = tl.load(pinc + 1)
inc_b = tl.multiple_of(inc_b, 8)
inc_a = tl.multiple_of(inc_a, 8)
inc_b = inc_b * stride_kb
if meta['DDS']:
inc_a = tl.load(pinc)
inc_b = tl.load(pinc + 1)
inc_a = tl.multiple_of(inc_a, 8)
inc_b = tl.multiple_of(inc_b, 8)
inc_a = inc_a * stride_ka
pa += inc_a
pb += inc_b
# pre-fetch

View File

@@ -57,7 +57,7 @@ class DeepSpeedDiffusersTransformerBlock(nn.Module):
self.attn_2.do_out_bias = False
self.attn_2_bias = self.attn_2.attn_ob
else:
self.attn_2_bias = nn.Paramaeter(torch.zeros_like(self.norm3_g), requires_grad=False)
self.attn_2_bias = nn.Parameter(torch.zeros_like(self.norm3_g), requires_grad=False)
self.gated_activation = GatedActivationOp()
self.layer_norm = LayerNormOp()

View File

@@ -335,27 +335,29 @@ class DeepSpeedTransformerLayer(nn.Module):
self.norm_b = nn.Parameter(torch.Tensor(self.config.hidden_size))
self.init_transformer_weights(self.config.adjust_init_range)
else:
# For testing only.
q = initial_weights[0].data
k = initial_weights[1].data
v = initial_weights[2].data
if initial_weights is not None:
# For testing only.
q = initial_weights[0].data
k = initial_weights[1].data
v = initial_weights[2].data
self.attn_qkvw = nn.Parameter(torch.cat((q, k, v)))
#self.attn_qkvw[i * self.config.hidden_size:(i + 1) * self.config.hidden_size] = \
# initial_weights[i].clone()
#torch.empty_like(initial_weights[i]).data.copy_(initial_weights[i].data)
self.attn_qkvb = nn.Parameter(torch.Tensor(self.config.hidden_size * 3))
self.attn_qkvb.data.zero_()
self.attn_ow = initial_weights[3]
self.attn_ob = initial_biases[3]
self.attn_nw = initial_weights[4]
self.attn_nb = initial_biases[4]
self.inter_w = initial_weights[5]
self.inter_b = initial_biases[5]
self.output_w = initial_weights[6]
self.output_b = initial_biases[6]
self.norm_w = initial_weights[7]
self.norm_b = initial_biases[7]
self.attn_qkvw = nn.Parameter(torch.cat((q, k, v)))
#self.attn_qkvw[i * self.config.hidden_size:(i + 1) * self.config.hidden_size] = \
# initial_weights[i].clone()
#torch.empty_like(initial_weights[i]).data.copy_(initial_weights[i].data)
self.attn_qkvb = nn.Parameter(torch.Tensor(self.config.hidden_size * 3))
self.attn_qkvb.data.zero_()
self.attn_ow = initial_weights[3]
self.attn_nw = initial_weights[4]
self.inter_w = initial_weights[5]
self.output_w = initial_weights[6]
self.norm_w = initial_weights[7]
if initial_biases is not None:
self.attn_ob = initial_biases[3]
self.attn_nb = initial_biases[4]
self.inter_b = initial_biases[5]
self.output_b = initial_biases[6]
self.norm_b = initial_biases[7]
# Load cuda modules if needed
global transformer_cuda_module, stochastic_transformer_cuda_module

View File

@@ -715,7 +715,7 @@ class DistributedDataAnalyzer(object):
buffer = torch.cat(tensor_list, dim=0).to(self.device)
write_buffer_to_file(buffer, 0, builder)
elif self.worker_id == 0 and src > 0: # rank 0 receives other rank's data and writes it
buffer = torch.empty(sizes[src].item(), dtype=buffer.dtype, device=buffer.device)
buffer = torch.empty(sizes[src].item(), dtype=numpy_dtype, device=self.device)
err = dist.recv(buffer, src=src, group=self.comm_group, tag=src)
assert err == src and len(buffer) > 0, "recv failed"
write_buffer_to_file(buffer, src, builder)

View File

@@ -407,7 +407,6 @@ class DeepSpeedEngine(Module):
for _, module in self.module.named_modules():
if isinstance(module, LoRAOptimizedLinear):
self.optimized_linear_lora_enabled = True
offload_ratio = None
if offload_ratio is not None:
assert offload_ratio == module.lora_config.offload_ratio, \
"all lora_config offload ratios should be the same across the model"
@@ -1262,10 +1261,6 @@ class DeepSpeedEngine(Module):
@staticmethod
def __check_params(model: Module, dtype: torch.dtype) -> None:
return
if not all(param.dtype == dtype for param in model.parameters()) and dist.get_rank() == 0:
raise ValueError(f"{dtype} is enabled but the following parameters have dtype that is "
f"not {dtype}: "
f"{[(n, p.dtype) for n, p in model.named_parameters() if p.dtype != dtype]}")
def _set_client_model(self, model):
# register client model in _modules so that nn.module methods work correctly

View File

@@ -13,7 +13,6 @@ from deepspeed import comm as dist
from deepspeed.utils import logger
from deepspeed.utils.timer import ThroughputTimer
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.bf16_optimizer import BF16_Optimizer
from ..engine import DeepSpeedEngine, MEMORY_OPT_ALLREDUCE_SIZE
@@ -712,7 +711,6 @@ class PipelineEngine(DeepSpeedEngine):
def _exec_forward_pass(self, buffer_id):
self.tput_timer.start()
self.mem_status('BEFORE FWD', reset_max=True)
if isinstance(self.pipe_buffers['inputs'][buffer_id], tuple):
inputs = tuple(t.clone() for t in self.pipe_buffers['inputs'][buffer_id])
@@ -808,13 +806,10 @@ class PipelineEngine(DeepSpeedEngine):
assert self.optimizer is not None, "must provide optimizer during " \
"init in order to use backward"
self.mem_status('BEFORE BWD', reset_max=True)
# The last stage just runs backward on the loss using DeepSpeed's typical
# mechanisms.
if self.is_last_stage():
super().backward(self.loss)
self.mem_status('AFTER BWD')
return
outputs = self.pipe_buffers['outputs'][buffer_id]
@@ -881,8 +876,6 @@ class PipelineEngine(DeepSpeedEngine):
self.timers(BACKWARD_MICRO_TIMER).stop()
self.timers(BACKWARD_GLOBAL_TIMER).stop()
self.mem_status('AFTER BWD')
def _exec_load_micro_batch(self, buffer_id):
if self.wall_clock_breakdown():
self.timers(BATCH_INPUT_TIMER).start()
@@ -1221,14 +1214,11 @@ class PipelineEngine(DeepSpeedEngine):
if self.wall_clock_breakdown():
self.timers(STEP_MICRO_TIMER).start()
self.timers(STEP_GLOBAL_TIMER).start()
self.mem_status('BEFORE STEP', reset_max=True)
self._force_grad_boundary = True
self._take_model_step(lr_kwargs)
self._force_grad_boundary = False
self.mem_status('AFTER STEP')
if self.global_rank == 0 and self.monitor.enabled:
self.summary_events = [(f'Train/Samples/lr', self.get_lr()[0], self.global_samples)]
if self.fp16_enabled() and hasattr(self.optimizer, 'cur_scale'):
@@ -1307,53 +1297,6 @@ class PipelineEngine(DeepSpeedEngine):
"""Disabled for pipeline parallel training. See ``train_batch()``. """
raise PipelineError("Only train_batch() is accessible in pipeline mode.")
def mem_status(self, msg, print_rank=-1, reset_max=False):
return
global mem_alloced, mem_cached
if not self.global_steps == 0 or not self.global_steps == 9:
#return
pass
if self.mpu.get_data_parallel_rank() != 0:
return
if self.global_rank != 0:
return
rank = self.global_rank
if print_rank != -1 and rank != print_rank:
return
get_accelerator().synchronize()
if reset_max:
get_accelerator().reset_max_memory_cached()
get_accelerator().reset_max_memory_allocated()
new_alloced = get_accelerator().memory_allocated()
new_cached = get_accelerator().memory_cached()
delta_alloced = new_alloced - mem_alloced
delta_cached = new_cached - mem_cached
mem_cached = new_cached
mem_alloced = new_alloced
max_alloced = get_accelerator().max_memory_allocated()
max_cached = get_accelerator().max_memory_cached()
# convert to GB for printing
new_alloced /= 1024**3
new_cached /= 1024**3
delta_alloced /= 1024**3
delta_cached /= 1024**3
max_alloced /= 1024**3
max_cached /= 1024**3
print(
f'RANK={rank} STAGE={self.stage_id} STEP={self.global_steps} MEMSTATS', msg,
f'current alloc={new_alloced:0.4f}GB (delta={delta_alloced:0.4f}GB max={max_alloced:0.4f}GB) '
f'current cache={new_cached:0.4f}GB (delta={delta_cached:0.4f}GB max={max_cached:0.4f}GB)')
def module_state_dict(self, exclude_frozen_parameters=False):
"""Override hack to save a pipe model and return the directory path of the save.

View File

@@ -15,10 +15,10 @@ class SparseTensor(object):
def __init__(self, dense_tensor=None):
self.orig_dense_tensor = dense_tensor
self.dtype = self.orig_dense_tensor.dtype
self.is_sparse = dense_tensor.is_sparse
if dense_tensor is not None:
if dense_tensor.is_sparse:
self.is_sparse = dense_tensor.is_sparse
self.dtype = self.orig_dense_tensor.dtype
if self.is_sparse:
dense_tensor = dense_tensor.coalesce()
self.indices = dense_tensor.indices().flatten()
self.values = dense_tensor.values()

View File

@@ -1656,16 +1656,11 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size))
if rank is None:
# "All Reducing"
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
else:
global_rank = dist.get_global_rank(self.dp_process_group, rank)
dist.reduce(tensor_to_allreduce, global_rank, group=self.dp_process_group)
# "All Reducing"
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
if communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce:
if rank is None or rank == dist.get_rank(group=self.dp_process_group):
tensor.copy_(tensor_to_allreduce)
tensor.copy_(tensor_to_allreduce)
return tensor

View File

@@ -1346,10 +1346,9 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(0, dest_offset, num_elements)
grad_accum = self.get_param_gradient_attribute(param)
if grad_accum is None:
src_tensor = grad_accum.view(-1).narrow(0, source_offset, num_elements)
else:
src_tensor = grad_accum.view(-1).narrow(0, source_offset, num_elements)
assert grad_accum is not None
src_tensor = grad_accum.view(-1).narrow(0, source_offset, num_elements)
if not self.fp16_master_weights_and_gradients:
src_tensor = src_tensor.float()

View File

@@ -124,7 +124,7 @@ class FPDT_InputConstruct(torch.nn.Module):
load_balanced_tokens = self.tokens[:, indices]
load_balanced_labels = self.labels[:, indices] if self.labels is not None else self.labels
load_balanced_attention_mask = self.attention_mask if self.attention_mask is not None else self.attention_mask
load_balanced_attention_mask = self.attention_mask
load_balanced_position_ids = self.position_ids[:,
indices] if self.position_ids is not None else self.position_ids

View File

@@ -404,7 +404,7 @@ def _create_expert_data_and_model_parallel(expert_parallel_size_, mpu, use_data_
world_size = dist.get_world_size()
rank = dist.get_rank()
dp_world_size = mpu.get_data_parallel_world_size()
dp_world_size = _get_data_parallel_world_size()
pp_world_size = 1 if mpu is None else bwc_pipeline_parallel_world_size(mpu)
_ensure_divisibility(world_size, tensor_parallel_size_)
@@ -569,31 +569,37 @@ def _get_data_parallel_group_ranks():
def _get_broadcast_src_rank():
assert dist.is_initialized(), 'dist is not initialized'
return dist.get_global_rank(_get_sequence_data_parallel_group(), 0)
def _get_expert_broadcast_src_rank(group_name):
assert dist.is_initialized(), 'dist is not initialized'
return dist.get_global_rank(_get_expert_data_parallel_group(group_name), 0)
def _get_expert_parallel_world_size(group_name):
"""Return world size for the expert parallel group."""
assert dist.is_initialized(), 'dist is not initialized'
return dist.get_world_size(group=_get_expert_parallel_group(group_name))
def _get_expert_data_parallel_world_size(group_name):
"""Return world size for the expert data parallel group."""
assert dist.is_initialized(), 'dist is not initialized'
return dist.get_world_size(group=_get_expert_data_parallel_group(group_name))
def _get_expert_parallel_rank(group_name):
"""Return my rank for the expert parallel group."""
assert dist.is_initialized(), 'dist is not initialized'
return dist.get_rank(group=_get_expert_parallel_group(group_name))
def _get_expert_parallel_src_rank(group_name):
"""Calculate the global rank corresponding to a local rank zero
in the expert parallel group."""
assert dist.is_initialized(), 'dist is not initialized'
global_rank = dist.get_rank()
local_world_size = _get_expert_parallel_world_size(group_name)
return (global_rank // local_world_size) * local_world_size
@@ -601,11 +607,13 @@ def _get_expert_parallel_src_rank(group_name):
def _get_expert_data_parallel_rank(group_name):
"""Return my rank for the expert data parallel group."""
assert dist.is_initialized(), 'dist is not initialized'
return dist.get_rank(group=_get_expert_data_parallel_group(group_name))
def _get_data_parallel_world_size():
"""Return world size for the data parallel group."""
assert dist.is_initialized(), 'dist is not initialized'
if mesh_device is not None:
return dist.get_world_size(mesh_device.get_group(mesh_dim="data_parallel"))
global mpu
@@ -627,11 +635,13 @@ def _get_model_parallel_world_size():
def _get_data_parallel_rank():
"""Return my rank for the data parallel group."""
assert dist.is_initialized(), 'dist is not initialized'
return dist.get_rank(group=_get_data_parallel_group())
def _get_sequence_parallel_world_size():
"""Return world size for the sequence parallel group."""
"""Return world size for the model parallel group."""
assert dist.is_initialized(), 'dist is not initialized'
global mpu
if mesh_device is not None:
return dist.get_world_size(mesh_device.get_group(mesh_dim="sequence_parallel"))

View File

@@ -431,7 +431,6 @@ class OpBuilder(ABC):
print(f"{WARNING} {self.name} cuda is missing or is incompatible with installed torch, "
"only cpu ops can be compiled!")
return '-D__DISABLE_CUDA__'
return '-D__DISABLE_CUDA__'
def _backup_cpuinfo(self):
# Construct cpu_info dict from lscpu that is similar to what py-cpuinfo provides

View File

@@ -39,7 +39,6 @@ class CCLCommBuilder(CPUOpBuilder):
raise ValueError(
"Didn't find CCL_ROOT, install oneCCL from https://github.com/oneapi-src/oneCCL and source its environment variable"
)
return []
else:
return ['-lccl', f'-L{ccl_root_path}/lib']