[bucketing] Reduce CPU overhead for reduce_scatter_merge_fn_to_trace (#159723)

The previous implementation created `n_gpu * n_tensors` intermediate tensors, which added a lot of CPU overhead, especially given that inductor was generating a number of individual tensor copy kernels for `torch.cat`.

This PR changes the implementation so that only `n_tensors` intermediate tensors are created, making the CPU overhead proportional to the number of tensors being bucketed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159723
Approved by: https://github.com/IvanKobzarev
This commit is contained in:
Francisco Massa
2025-08-03 09:16:55 +00:00
committed by PyTorch MergeBot
parent 805a102beb
commit 9a680e14b7

View File

@ -235,36 +235,20 @@ def reduce_scatter_merge_fn_to_trace(
reduce_dtype: torch.dtype, # type: ignore[name-defined]
device: torch.device, # type: ignore[name-defined]
) -> list[torch.Tensor]: # type: ignore[no-untyped-def]
rs_ins_flattened = [rs_in.view(-1) for rs_in in rs_ins]
rs_ins_flattened = [x.view(group_size, -1) for x in rs_ins]
rs_ins_srcs = [
rs_in_f.split([rs_in_f.numel() // group_size] * group_size)
for rs_in_f in rs_ins_flattened
]
new_out_sizes = [(x.shape[0] // group_size,) + x.shape[1:] for x in rs_ins]
new_out_numels = [x.numel() // group_size for x in rs_ins]
foreach_copy_srcs = []
for rank_idx in range(group_size):
for rs_in_idx in range(len(rs_ins)):
foreach_copy_srcs.append(rs_ins_srcs[rs_in_idx][rank_idx])
new_rs_in = torch.cat(rs_ins_flattened, dim=1).flatten()
new_rs_in = torch.cat(foreach_copy_srcs, dim=0)
wait_tensor = torch.ops.c10d_functional.wait_tensor(
new_rs_out = torch.ops.c10d_functional.wait_tensor(
torch.ops._c10d_functional.reduce_scatter_tensor.default(
new_rs_in, reduce_op, group_size, group_name
)
)
new_rs_out = wait_tensor
new_outs = []
new_rs_out_offset = 0
for rs_in in rs_ins:
new_out_size = torch.Size((rs_in.shape[0] // group_size,) + rs_in.shape[1:]) # type: ignore[attr-defined]
new_out = new_rs_out.narrow(0, new_rs_out_offset, new_out_size.numel()).reshape(
new_out_size
)
new_outs.append(new_out)
new_rs_out_offset += new_out_size.numel()
new_out_flat = new_rs_out.split(new_out_numels, 0)
new_outs = [x.view(s) for x, s in zip(new_out_flat, new_out_sizes)]
return new_outs