Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
Overlap scheduler improvements (#165318)
Bucketing a number of smallish improvements:

- Account for bucketing in the overlap calculation: if an in-flight collective exists with the same bucket key, reduce the new collective's estimated time by its latency (see the sketch after this description).
- Update compute domination so that we order based on compute index rather than compute depth, so we never reorder compute. This makes it a bit easier to reason about memory and pre-fetching, although we could explore reordering in the future.
- When we wait on a collective, force all collectives on the same process group that were enqueued prior to it to wait as well.

Better memory handling:

- Pre-fetch limiting: when scheduling collectives for overlap, only pre-fetch up to a certain distance, then schedule off-path collectives (which are typically memory reducing).
- When we are above peak memory, schedule waits.

TODO:

- For each compute node, we know its original memory in the graph, so we could limit pre-fetching that crosses peak memory.
- By scheduling off-path collectives for overlap, we reduce memory, but if there isn't enough compute for overlap, we need to proactively schedule them. Not an issue yet on the examples.
- Make the hard-coded constants configurable and clean up enablement (can be done in a subsequent PR).

On small llama 2D backward: 578 of 618 potentially hideable collectives hidden; original mem 14.4 GB, rescheduled mem 15.9 GB. On forward: 254/256 potentially hideable collectives hidden; original mem 5.8 GB, rescheduled mem 5.8 GB.

WIP: adding tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165318
Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
ghstack dependencies: #164738, #164783, #164944, #164945, #165059
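A minimal sketch of the bucketing-aware latency accounting described in the first bullet, under assumed names: `InFlightCollective`, `estimate_exposed_time`, and the cost fields are illustrative only, not the scheduler's actual API. It shows the shape of the discount, not the implementation in this PR:

```python
# Hypothetical sketch: if an in-flight collective shares the new collective's
# bucket key, the two can be bucketed into one call, so the new collective
# should not pay the per-call latency cost a second time.
from dataclasses import dataclass
from typing import Optional


@dataclass
class InFlightCollective:
    bucket_key: Optional[object]  # e.g. (group_name, reduce_op, dtype)
    latency_s: float              # fixed per-call launch/latency cost
    transfer_s: float             # size-dependent transfer cost


def estimate_exposed_time(
    new_latency_s: float,
    new_transfer_s: float,
    new_bucket_key: Optional[object],
    in_flight: list[InFlightCollective],
) -> float:
    """Estimated time a newly scheduled collective adds to the schedule."""
    estimate = new_latency_s + new_transfer_s
    if new_bucket_key is not None and any(
        c.bucket_key == new_bucket_key for c in in_flight
    ):
        # Same bucket key already in flight: drop the latency component.
        estimate -= new_latency_s
    return estimate
```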
committed by PyTorch MergeBot
parent bc1f2108d7
commit b3f6d49b69
@@ -34,6 +34,15 @@ def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
     return (group_name, reduce_op, dtype)
 
 
+def bucket_key(node: torch.fx.Node) -> Optional[object]:
+    if is_all_gather_into_tensor(node):
+        return _ag_group_key(node)
+    elif is_reduce_scatter_tensor(node):
+        return _rs_group_key(node)
+    else:
+        return None
+
+
 def bucket_cap_mb_by_bucket_idx_default(bucket_id: int) -> float:
     """
     Determine the size of a bucket based on its ID.
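For context, a minimal sketch of how `bucket_key` might be consumed when forming buckets; `group_by_bucket_key` is a hypothetical helper and not part of this diff:

```python
# Hypothetical usage sketch: group candidate collective nodes by bucket_key so
# that all-gathers / reduce-scatters sharing (group, op, dtype) can be fused.
from collections import defaultdict

import torch.fx


def group_by_bucket_key(
    nodes: list[torch.fx.Node],
) -> dict[object, list[torch.fx.Node]]:
    buckets: dict[object, list[torch.fx.Node]] = defaultdict(list)
    for node in nodes:
        key = bucket_key(node)  # None means the node is not bucketable
        if key is not None:
            buckets[key].append(node)
    return dict(buckets)
```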