add split sizes info dump for uneven all2all bw calculation (#151438)

Add split sizes info to the dumped execution trace and Kineto trace for bandwidth calculation of uneven all2all.

Take the input data in the case below as an example: although the input size of Rank-0 is 50 elements, the data that Rank-0 actually sends out is (12+13+14)=39 elements, because Rank-0 does not send the first chunk of 11 elements to peers. The traces cannot show this today, because the "in split size" field is empty.
![image](https://github.com/user-attachments/assets/7240f334-2081-409b-bbe0-a8396ffa2d30)
![image](https://github.com/user-attachments/assets/679fc49f-e34f-4a74-bad0-fb6fa9d18239)
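The arithmetic in the example above can be sketched as follows. This is a minimal illustration, not code from this PR: `elementsSentToPeers` and `sendBandwidthBytesPerSec` are hypothetical helper names, and the sketch assumes that chunk `i` of a rank's input is addressed to rank `i`, so the chunk at the rank's own index never goes to a peer.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical helper (not part of PyTorch): number of elements a rank
// sends to OTHER ranks in an all2all, given its per-destination input
// split sizes. Assumes chunk i is addressed to rank i, so the chunk at
// index selfRank is excluded from the count.
int64_t elementsSentToPeers(
    const std::vector<int64_t>& inSplitSizes,
    std::size_t selfRank) {
  assert(selfRank < inSplitSizes.size());
  const int64_t total = std::accumulate(
      inSplitSizes.begin(), inSplitSizes.end(), int64_t{0});
  return total - inSplitSizes[selfRank];
}

// Hypothetical helper: bytes sent divided by elapsed time, in bytes/second.
double sendBandwidthBytesPerSec(
    int64_t elements,
    std::size_t elementSizeBytes,
    double seconds) {
  assert(seconds > 0.0);
  return static_cast<double>(elements) *
      static_cast<double>(elementSizeBytes) / seconds;
}
```

With the split sizes from the example, `elementsSentToPeers({11, 12, 13, 14}, 0)` yields 39 rather than the 50 that the total input size alone would suggest, which is why the split sizes need to be present in the trace.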

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151438
Approved by: https://github.com/shengfukevin, https://github.com/kwen2501
This commit is contained in:
Sanshan Gao
2025-05-01 01:19:16 +00:00
committed by PyTorch MergeBot
parent 7abca8ceba
commit be1adcae32


@@ -5071,6 +5071,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall(
     const AllToAllOptions& opts) {
   int64_t input_total_numel = 0;
   int64_t output_total_numel = 0;
+  // considering uneven all2all bw calculation
+  // use split sizes field to record tensor list sizes
+  std::vector<int64_t> inSplitSizes;
+  std::vector<int64_t> outSplitSizes;
   auto device = outputTensors[0].device();
   for (const auto r : c10::irange(outputTensors.size())) {
@@ -5082,6 +5086,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall(
         "Tensors must be on the same device")
     input_total_numel += inputTensors[r].numel();
     output_total_numel += outputTensors[r].numel();
+    inSplitSizes.push_back(inputTensors[r].numel());
+    outSplitSizes.push_back(outputTensors[r].numel());
   }
   RECORD_PARAM_COMMS_DATA(
@@ -5096,8 +5102,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall(
       input_total_numel, // inNelems
       output_total_numel, // outNelems
       inputTensors.front().scalar_type(), // dType
-      std::vector<int64_t>(), // inSplitSizes
-      std::vector<int64_t>(), // outSplitSizes
+      inSplitSizes, // inSplitSizes
+      outSplitSizes, // outSplitSizes
       globalRankStart_, // globalRankStart_
       globalRankStride_, // globalRankStride_
       this->getSize()); // worldSize
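
For readers without a libtorch build at hand, the bookkeeping this change adds can be approximated in isolation. The sketch below is an assumption-laden stand-in, not the PyTorch implementation: `FakeTensor`, `SplitInfo`, and `collectSplitInfo` are hypothetical names, and `FakeTensor` models only `numel()`.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical stand-in for at::Tensor; only numel() is modeled.
struct FakeTensor {
  int64_t n;
  int64_t numel() const { return n; }
};

// Hypothetical result type: the total element count plus per-tensor sizes.
struct SplitInfo {
  int64_t totalNumel = 0;
  std::vector<int64_t> splitSizes;
};

// Walks a tensor list once, accumulating the total element count and
// recording each tensor's element count as one split size, mirroring
// the loop in the diff above.
SplitInfo collectSplitInfo(const std::vector<FakeTensor>& tensors) {
  SplitInfo info;
  info.splitSizes.reserve(tensors.size());
  for (const auto& t : tensors) {
    info.totalNumel += t.numel();
    info.splitSizes.push_back(t.numel());
  }
  return info;
}
```

Applied to four tensors of 11, 12, 13, and 14 elements, this produces a total of 50 and split sizes `[11, 12, 13, 14]`, the two pieces of information a trace consumer needs to separate the total input size from the data sent to peers.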