[bucketing] Use max of input/output size for bucketing (#159717)

The output of a reduce_scatter is n_gpu times smaller than its input, while the output of an all_gather is n_gpu times larger than its input. This means that in the current heuristic for bucketing reduce_scatter, we would need to use a bucket size which is n_gpu times larger than the bucket for all_gather, making it gpu-dependent and less intuitive. This PRs propose to use instead the max between the input and output sizes, so that one can use the same bucket_size value for both passes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159717
Approved by: https://github.com/wconstab
This commit is contained in:
Francisco Massa
2025-08-02 22:42:19 +00:00
committed by PyTorch MergeBot
parent be71000ff5
commit d2792f51b2

View File

@ -133,10 +133,10 @@ def greedy_bucket_collective_by_mb(
assert "val" in node.meta assert "val" in node.meta
n_val = node.meta["val"] n_val = node.meta["val"]
out_size_bytes = n_val.numel() * n_val.element_size() out_size_bytes = n_val.numel() * n_val.element_size()
if ( n_input_val = node.all_input_nodes[0].meta["val"]
cur_bucket_size_bytes + out_size_bytes > bucket_size_bytes in_size_bytes = n_input_val.numel() * n_input_val.element_size()
and cur_bucket size_bytes = max(out_size_bytes, in_size_bytes)
): if cur_bucket_size_bytes + size_bytes > bucket_size_bytes and cur_bucket:
# Current bucket is full, create new bucket # Current bucket is full, create new bucket
if len(cur_bucket) > 1: if len(cur_bucket) > 1:
buckets.append(cur_bucket) buckets.append(cur_bucket)
@ -144,7 +144,7 @@ def greedy_bucket_collective_by_mb(
cur_bucket_size_bytes = 0 cur_bucket_size_bytes = 0
cur_bucket_id += 1 cur_bucket_id += 1
cur_bucket_successors = OrderedSet() cur_bucket_successors = OrderedSet()
cur_bucket_size_bytes += out_size_bytes cur_bucket_size_bytes += size_bytes
cur_bucket.append(node) cur_bucket.append(node)
cur_bucket_successors |= nodes_successors[node] cur_bucket_successors |= nodes_successors[node]
if len(cur_bucket) > 1: if len(cur_bucket) > 1: