Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Overlap scheduler improvements (#165318)
Bucketing a number of smallish improvements:

- Account for bucketing in the overlap calculation: if an in-flight collective exists with the same bucket key, reduce the new collective's estimated time by that collective's latency (a minimal sketch of this and of pre-fetch limiting follows below).
- Update compute domination so that ordering is based on compute index rather than compute depth, so we never reorder compute. This makes it a bit easier to reason about memory and pre-fetching, although we can explore reordering in the future.
- When we wait on a collective, force all collectives on the same process group that were enqueued prior to it to wait as well.

Better memory handling:

- Pre-fetch limiting: when scheduling collectives for overlap, only pre-fetch up to a certain distance, then schedule off-path collectives (which are typically memory reducing).
- When we are above peak memory, schedule waits.

TODO:

- For each compute node we know its original memory in the graph; we could limit pre-fetching that crosses peak memory.
- Scheduling off-path collectives for overlap reduces memory, but if there is not enough compute to overlap with, we need to schedule them proactively. Not an issue yet in the examples.
- Make the hard-coded constants configurable and clean up enablement (can be done in a subsequent PR).

On small Llama 2D backward: 578 of 618 potentially hideable collectives hidden; original memory 14.4 GB, rescheduled memory 15.9 GB.
On forward: 254 of 256 potentially hideable collectives hidden; original memory 5.8 GB, rescheduled memory 5.8 GB.

WIP: adding tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165318
Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
ghstack dependencies: #164738, #164783, #164944, #164945, #165059
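The bucket-key latency discount and the pre-fetch distance limit mentioned above can be illustrated with a minimal sketch. This is a hedged illustration, not the PR's code: `InFlightCollective`, `estimate_with_bucketing`, `allow_prefetch`, and `MAX_PREFETCH_DISTANCE` are hypothetical names chosen for exposition, and the constant value is made up.

```python
# Hedged sketch of two ideas from this PR; every name here is illustrative
# and does not correspond to identifiers in the actual scheduler code.
from dataclasses import dataclass


@dataclass
class InFlightCollective:
    bucket_key: object   # e.g. (process group, collective type, dtype)
    latency_ns: float    # fixed launch/latency cost already being paid


def estimate_with_bucketing(
    bucket_key: object,
    base_estimate_ns: float,
    in_flight: list[InFlightCollective],
) -> float:
    """If an in-flight collective shares the bucket key, the new collective
    could be bucketed with it, so discount the new estimate by the latency
    the in-flight collective already pays."""
    for coll in in_flight:
        if coll.bucket_key == bucket_key:
            return max(base_estimate_ns - coll.latency_ns, 0.0)
    return base_estimate_ns


MAX_PREFETCH_DISTANCE = 5  # hypothetical stand-in for a hard-coded constant


def allow_prefetch(current_compute_idx: int, consumer_compute_idx: int) -> bool:
    """Only pre-fetch a collective whose consumer is within a bounded number
    of compute nodes ahead; beyond that distance, prefer scheduling off-path
    (typically memory-reducing) collectives instead."""
    return consumer_compute_idx - current_compute_idx <= MAX_PREFETCH_DISTANCE
```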
Committed by: PyTorch MergeBot
Parent: bc1f2108d7
Commit: b3f6d49b69
```diff
@@ -356,7 +356,9 @@ def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int:
     return size
 
 
-def estimate_nccl_collective_runtime_from_fx_node(fx_node: torch.fx.Node) -> float:
+def estimate_nccl_collective_runtime_from_fx_node(
+    fx_node: torch.fx.Node, override_size: Optional[int] = None
+) -> float:
     """
     Returns estimated NCCL collective runtime in nanoseconds (ns).
 
@@ -371,7 +373,10 @@ def estimate_nccl_collective_runtime_from_fx_node(fx_node: torch.fx.Node) -> flo
     """
     from torch.distributed.distributed_c10d import _get_group_size_by_name
 
-    tensor_storage_size_bytes = estimate_fx_collective_size(fx_node)
+    if override_size is None:
+        tensor_storage_size_bytes = estimate_fx_collective_size(fx_node)
+    else:
+        tensor_storage_size_bytes = override_size
 
     assert not isinstance(fx_node.target, str)
     opt_args_kwargs = normalize_function(
```
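One way the new `override_size` parameter could be used is to price a planned bucket by passing the bucket's total payload while reusing one member node for the collective's metadata. The diff above does not show the caller, so the snippet below is an assumed call pattern; the import path and the helper name are assumptions.

```python
# Assumed usage sketch; the module path and the summing helper are assumptions,
# not code shown in this diff.
import torch.fx
from torch._inductor.comm_analysis import (  # module path assumed
    estimate_fx_collective_size,
    estimate_nccl_collective_runtime_from_fx_node,
)


def estimated_bucket_runtime_ns(member_nodes: list[torch.fx.Node]) -> float:
    # A bucketed collective moves the combined payload of all its members, so
    # pass the summed byte size via override_size instead of letting the
    # estimator measure only the single representative node.
    total_bytes = sum(estimate_fx_collective_size(n) for n in member_nodes)
    return estimate_nccl_collective_runtime_from_fx_node(
        member_nodes[0], override_size=total_bytes
    )
```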