Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[bucketing] Reduce CPU overhead for reduce_scatter_merge_fn_to_trace (#159723)
The previous implementation created `n_gpu * n_tensors` intermediate tensors, which added a lot of CPU overhead, especially since Inductor was generating a number of individual tensor-copy kernels for the `torch.cat`. This PR changes the implementation so that only `n_tensors` intermediates are created, making the CPU overhead proportional to the number of tensors being bucketed rather than to that number times the group size. For example, bucketing 64 tensors across 8 ranks previously materialized 512 intermediate slices; it now creates 64 views.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159723
Approved by: https://github.com/IvanKobzarev
Committed by: PyTorch MergeBot
Parent: 805a102beb
Commit: 9a680e14b7
@@ -235,36 +235,20 @@ def reduce_scatter_merge_fn_to_trace(
     reduce_dtype: torch.dtype,  # type: ignore[name-defined]
     device: torch.device,  # type: ignore[name-defined]
 ) -> list[torch.Tensor]:  # type: ignore[no-untyped-def]
-    rs_ins_flattened = [rs_in.view(-1) for rs_in in rs_ins]
-
-    rs_ins_srcs = [
-        rs_in_f.split([rs_in_f.numel() // group_size] * group_size)
-        for rs_in_f in rs_ins_flattened
-    ]
-
-    foreach_copy_srcs = []
-    for rank_idx in range(group_size):
-        for rs_in_idx in range(len(rs_ins)):
-            foreach_copy_srcs.append(rs_ins_srcs[rs_in_idx][rank_idx])
-
-    new_rs_in = torch.cat(foreach_copy_srcs, dim=0)
-
-    wait_tensor = torch.ops.c10d_functional.wait_tensor(
+    rs_ins_flattened = [x.view(group_size, -1) for x in rs_ins]
+
+    new_out_sizes = [(x.shape[0] // group_size,) + x.shape[1:] for x in rs_ins]
+    new_out_numels = [x.numel() // group_size for x in rs_ins]
+
+    new_rs_in = torch.cat(rs_ins_flattened, dim=1).flatten()
+
+    new_rs_out = torch.ops.c10d_functional.wait_tensor(
         torch.ops._c10d_functional.reduce_scatter_tensor.default(
             new_rs_in, reduce_op, group_size, group_name
         )
     )
-    new_rs_out = wait_tensor
-
-    new_outs = []
-    new_rs_out_offset = 0
-    for rs_in in rs_ins:
-        new_out_size = torch.Size((rs_in.shape[0] // group_size,) + rs_in.shape[1:])  # type: ignore[attr-defined]
-        new_out = new_rs_out.narrow(0, new_rs_out_offset, new_out_size.numel()).reshape(
-            new_out_size
-        )
-        new_outs.append(new_out)
-        new_rs_out_offset += new_out_size.numel()
+
+    new_out_flat = new_rs_out.split(new_out_numels, 0)
+    new_outs = [x.view(s) for x, s in zip(new_out_flat, new_out_sizes)]
     return new_outs
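To see why the two formulations are interchangeable, here is a minimal standalone sketch (not part of the PyTorch source; the tensor shapes and `group_size` value are invented for illustration) checking that viewing each input as `(group_size, -1)` and concatenating along `dim=1` yields the same rank-major buffer that the old split-and-concatenate loop assembled:

    # Illustrative equivalence check -- shapes are made up; group_size plays
    # the role of the process-group world size.
    import torch

    group_size = 4
    rs_ins = [torch.randn(8, 3), torch.randn(16, 5), torch.randn(4, 2)]

    # Old layout: split every flattened input into group_size chunks, then
    # gather the chunks rank-major -- n_tensors * group_size intermediate
    # tensors feed the cat.
    srcs = [
        t.view(-1).split([t.numel() // group_size] * group_size) for t in rs_ins
    ]
    old_in = torch.cat(
        [srcs[i][r] for r in range(group_size) for i in range(len(rs_ins))]
    )

    # New layout: one (group_size, -1) view per input and a single cat along
    # dim=1; flattening the row-major result gives the same rank-major order.
    new_in = torch.cat([t.view(group_size, -1) for t in rs_ins], dim=1).flatten()

    assert torch.equal(old_in, new_in)

Either buffer feeds `reduce_scatter_tensor` unchanged; the saving is on the CPU side, since the `torch.cat` now has `n_tensors` inputs instead of `n_tensors * group_size`, and the output side likewise collapses the per-tensor `narrow`/`reshape` loop into a single `split` plus views.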