mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
add split sizes info dump for uneven all2all bw calculation (#151438)
Add split sizes info to dumped execution trace and kineto trace for bw calcuation of uneven all2all. Take input data as an example from case below, although we know input size of Rank-0 is 50 elements, actual data size that Rank-0 sends out is (12+13+14)=39 elements. Rank-0 doesn't send the 1st chunk of 11 elements to peers. But we don't know this infomation now, because "in split size" filed is empty.   Pull Request resolved: https://github.com/pytorch/pytorch/pull/151438 Approved by: https://github.com/shengfukevin, https://github.com/kwen2501
This commit is contained in:
committed by
PyTorch MergeBot
parent
7abca8ceba
commit
be1adcae32
@ -5071,6 +5071,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall(
|
||||
const AllToAllOptions& opts) {
|
||||
int64_t input_total_numel = 0;
|
||||
int64_t output_total_numel = 0;
|
||||
// considering uneven all2all bw calculation
|
||||
// use split sizes field to record tensor list sizes
|
||||
std::vector<int64_t> inSplitSizes;
|
||||
std::vector<int64_t> outSplitSizes;
|
||||
|
||||
auto device = outputTensors[0].device();
|
||||
for (const auto r : c10::irange(outputTensors.size())) {
|
||||
@ -5082,6 +5086,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall(
|
||||
"Tensors must be on the same device")
|
||||
input_total_numel += inputTensors[r].numel();
|
||||
output_total_numel += outputTensors[r].numel();
|
||||
inSplitSizes.push_back(inputTensors[r].numel());
|
||||
outSplitSizes.push_back(outputTensors[r].numel());
|
||||
}
|
||||
|
||||
RECORD_PARAM_COMMS_DATA(
|
||||
@ -5096,8 +5102,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall(
|
||||
input_total_numel, // inNelems
|
||||
output_total_numel, // outNelems
|
||||
inputTensors.front().scalar_type(), // dType
|
||||
std::vector<int64_t>(), // inSplitSizes
|
||||
std::vector<int64_t>(), // outSplitSizes
|
||||
inSplitSizes, // inSplitSizes
|
||||
outSplitSizes, // outSplitSizes
|
||||
globalRankStart_, // globalRankStart_
|
||||
globalRankStride_, // globalRankStride_
|
||||
this->getSize()); // worldSize
|
||||
|
Reference in New Issue
Block a user