diff --git a/torch/_inductor/fx_passes/bucketing.py b/torch/_inductor/fx_passes/bucketing.py
index dfaf3f8d8921..75dd3678d51c 100644
--- a/torch/_inductor/fx_passes/bucketing.py
+++ b/torch/_inductor/fx_passes/bucketing.py
@@ -235,36 +235,20 @@ def reduce_scatter_merge_fn_to_trace(
     reduce_dtype: torch.dtype,  # type: ignore[name-defined]
     device: torch.device,  # type: ignore[name-defined]
 ) -> list[torch.Tensor]:  # type: ignore[no-untyped-def]
-    rs_ins_flattened = [rs_in.view(-1) for rs_in in rs_ins]
+    rs_ins_flattened = [x.view(group_size, -1) for x in rs_ins]
 
-    rs_ins_srcs = [
-        rs_in_f.split([rs_in_f.numel() // group_size] * group_size)
-        for rs_in_f in rs_ins_flattened
-    ]
+    new_out_sizes = [(x.shape[0] // group_size,) + x.shape[1:] for x in rs_ins]
+    new_out_numels = [x.numel() // group_size for x in rs_ins]
 
-    foreach_copy_srcs = []
-    for rank_idx in range(group_size):
-        for rs_in_idx in range(len(rs_ins)):
-            foreach_copy_srcs.append(rs_ins_srcs[rs_in_idx][rank_idx])
+    new_rs_in = torch.cat(rs_ins_flattened, dim=1).flatten()
 
-    new_rs_in = torch.cat(foreach_copy_srcs, dim=0)
-
-    wait_tensor = torch.ops.c10d_functional.wait_tensor(
+    new_rs_out = torch.ops.c10d_functional.wait_tensor(
         torch.ops._c10d_functional.reduce_scatter_tensor.default(
             new_rs_in, reduce_op, group_size, group_name
         )
     )
-    new_rs_out = wait_tensor
-
-    new_outs = []
-    new_rs_out_offset = 0
-    for rs_in in rs_ins:
-        new_out_size = torch.Size((rs_in.shape[0] // group_size,) + rs_in.shape[1:])  # type: ignore[attr-defined]
-        new_out = new_rs_out.narrow(0, new_rs_out_offset, new_out_size.numel()).reshape(
-            new_out_size
-        )
-        new_outs.append(new_out)
-        new_rs_out_offset += new_out_size.numel()
+    new_out_flat = new_rs_out.split(new_out_numels, 0)
+    new_outs = [x.view(s) for x, s in zip(new_out_flat, new_out_sizes)]
 
     return new_outs
 