[Misc] enhance type hint for rearrange return value (#23519)
Signed-off-by: Andy Xie <andy.xning@gmail.com>
This commit is contained in:
@ -409,12 +409,14 @@ class EplbState:
|
||||
self.expert_rearrangement_step = 0
|
||||
self.rearrange(model)
|
||||
|
||||
def rearrange(self,
|
||||
model: MixtureOfExperts,
|
||||
is_profile: bool = False,
|
||||
execute_shuffle: bool = True,
|
||||
global_expert_load: Optional[torch.Tensor] = None,
|
||||
rank_mapping: Optional[dict[int, int]] = None) -> None:
|
||||
def rearrange(
|
||||
self,
|
||||
model: MixtureOfExperts,
|
||||
is_profile: bool = False,
|
||||
execute_shuffle: bool = True,
|
||||
global_expert_load: Optional[torch.Tensor] = None,
|
||||
rank_mapping: Optional[dict[int,
|
||||
int]] = None) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Rearrange the experts according to the current load.
|
||||
"""
|
||||
@ -548,6 +550,7 @@ class EplbState:
|
||||
" (profile) " if is_profile else " ",
|
||||
time_end - time_start,
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def recv_state() -> tuple[torch.Tensor, torch.Tensor]:
|
||||
@ -613,4 +616,4 @@ def _node_count_with_rank_mapping(
|
||||
if is_same_node and node_assignment[other_rank] == 0:
|
||||
node_assignment[other_rank] = next_node_id
|
||||
|
||||
return next_node_id
|
||||
return next_node_id
|
||||
|
Reference in New Issue
Block a user