Enforce contiguity for alltoall (#141816)

Summary: We cannot relax the alltoall contiguity requirement; relaxing it leads to wrong results.

Test Plan: Added a test.

Differential Revision: D66560930

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141816
Approved by: https://github.com/Skylion007, https://github.com/kwen2501, https://github.com/fduwjj, https://github.com/fegin, https://github.com/yoyoyocmu
This commit is contained in:
Xiaodong Wang
2024-12-04 10:17:39 +00:00
committed by PyTorch MergeBot
parent eff99a4b4b
commit 61dc5e9c0a
3 changed files with 12 additions and 4 deletions

View File

@@ -1940,6 +1940,10 @@ class ProcessGroupWithDispatchedCollectivesTests(MultiProcessTestCase):
output_tensor = torch.zeros(2, 2, device=torch.device(device))
dist.all_to_all_single(output_tensor, input_tensor)
input_tensor = input_tensor.t()
with self.assertRaisesRegex(ValueError, "Tensors must be contiguous"):
dist.all_to_all_single(output_tensor, input_tensor)
class ReduceOpTest(TestCase):
# Ref: https://github.com/pytorch/pytorch/issues/87191

View File

@@ -2762,6 +2762,10 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::alltoall_base(
assertDense(invalidArgument, {outputTensor});
assertDense(invalidArgument, {inputTensor});
if (!inputTensor.is_contiguous(inputTensor.suggest_memory_format())) {
C10_THROW_ERROR(ValueError, "Tensors must be contiguous");
}
const auto& device = outputTensor.device();
c10::intrusive_ptr<AsyncAlltoallWork> work;
auto tag = nextTag();

View File

@@ -4494,8 +4494,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
std::vector<int64_t>& outputSplitSizes,
std::vector<int64_t>& inputSplitSizes,
const AllToAllOptions& /* unused */) {
check_gpu_single_tensor(outputTensor, true);
check_gpu_single_tensor(inputTensor, true);
check_gpu_single_tensor(outputTensor);
check_gpu_single_tensor(inputTensor);
if (outputSplitSizes.empty() && inputSplitSizes.empty()) {
RECORD_PARAM_COMMS_DATA(
std::make_tuple(
@@ -4607,8 +4607,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall(
auto device = outputTensors[0].device();
for (const auto r : c10::irange(outputTensors.size())) {
check_gpu_single_tensor(outputTensors[r], true);
check_gpu_single_tensor(inputTensors[r], true);
check_gpu_single_tensor(outputTensors[r]);
check_gpu_single_tensor(inputTensors[r]);
TORCH_CHECK(
device == outputTensors[r].device() &&
device == inputTensors[r].device(),