Enforce contiguity for alltoall (#141816)
Summary: We cannot relax the alltoall contiguity requirement; doing so leads to wrong results, so the check is enforced.

Test Plan: Added a test.

Differential Revision: D66560930

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141816
Approved by: https://github.com/Skylion007, https://github.com/kwen2501, https://github.com/fduwjj, https://github.com/fegin, https://github.com/yoyoyocmu
Committed by: PyTorch MergeBot
Parent: eff99a4b4b
Commit: 61dc5e9c0a
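For callers, the effect of the change can be sketched as follows (a minimal illustration, not part of the patch, assuming a single-rank gloo process group on localhost with an arbitrary rendezvous port): a non-contiguous input to dist.all_to_all_single now raises ValueError instead of silently producing wrong results, and an explicit .contiguous() copy restores a working call.

```python
import os
import torch
import torch.distributed as dist

# Illustrative rendezvous settings for a single-process group (not from the patch).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

input_tensor = torch.arange(4, dtype=torch.float32).reshape(2, 2)
output_tensor = torch.zeros(2, 2)

# Contiguous input: works as before.
dist.all_to_all_single(output_tensor, input_tensor)

# A transposed view is non-contiguous and is now rejected.
try:
    dist.all_to_all_single(output_tensor, input_tensor.t())
except ValueError as e:
    print(e)  # message contains "Tensors must be contiguous"

# An explicit copy makes the data contiguous again and the call succeeds.
dist.all_to_all_single(output_tensor, input_tensor.t().contiguous())

dist.destroy_process_group()
```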
@@ -1940,6 +1940,10 @@ class ProcessGroupWithDispatchedCollectivesTests(MultiProcessTestCase):
         output_tensor = torch.zeros(2, 2, device=torch.device(device))
         dist.all_to_all_single(output_tensor, input_tensor)
 
+        input_tensor = input_tensor.t()
+        with self.assertRaisesRegex(ValueError, "Tensors must be contiguous"):
+            dist.all_to_all_single(output_tensor, input_tensor)
+
 
 class ReduceOpTest(TestCase):
     # Ref: https://github.com/pytorch/pytorch/issues/87191
@@ -2762,6 +2762,10 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::alltoall_base(
   assertDense(invalidArgument, {outputTensor});
   assertDense(invalidArgument, {inputTensor});
 
+  if (!inputTensor.is_contiguous(inputTensor.suggest_memory_format())) {
+    C10_THROW_ERROR(ValueError, "Tensors must be contiguous");
+  }
+
   const auto& device = outputTensor.device();
   c10::intrusive_ptr<AsyncAlltoallWork> work;
   auto tag = nextTag();
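The check added above is memory-format aware: a tensor that is dense in its suggested memory format, for example channels_last, still passes, while a genuinely strided view such as a transpose does not. A small self-contained illustration in plain PyTorch (shapes chosen arbitrarily):

```python
import torch

# Dense in channels_last layout: not row-major contiguous, but contiguous
# under its own memory format, so a check like the one above accepts it.
x = torch.randn(2, 3, 4, 5).to(memory_format=torch.channels_last)
print(x.is_contiguous())                                   # False
print(x.is_contiguous(memory_format=torch.channels_last))  # True

# A transposed view is not dense under any memory format and is rejected.
y = torch.randn(4, 4).t()
print(y.is_contiguous())                                   # False
```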
@@ -4494,8 +4494,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
     std::vector<int64_t>& outputSplitSizes,
     std::vector<int64_t>& inputSplitSizes,
     const AllToAllOptions& /* unused */) {
-  check_gpu_single_tensor(outputTensor, true);
-  check_gpu_single_tensor(inputTensor, true);
+  check_gpu_single_tensor(outputTensor);
+  check_gpu_single_tensor(inputTensor);
   if (outputSplitSizes.empty() && inputSplitSizes.empty()) {
     RECORD_PARAM_COMMS_DATA(
         std::make_tuple(
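The alltoall_base path shown here also branches on whether explicit split sizes were supplied. For completeness, a hedged sketch of that variant, reusing the same single-rank gloo setup as the earlier snippet so it runs without GPUs (the split values are arbitrary); the send and receive buffers must still be contiguous:

```python
import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

inp = torch.arange(6, dtype=torch.float32)  # flat, contiguous send buffer
out = torch.empty(6)

# Split sizes list one entry per rank and must sum to each tensor's length;
# with world_size == 1 every element stays on rank 0.
dist.all_to_all_single(out, inp, output_split_sizes=[6], input_split_sizes=[6])
print(torch.equal(out, inp))  # True

dist.destroy_process_group()
```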
@@ -4607,8 +4607,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall(
 
   auto device = outputTensors[0].device();
   for (const auto r : c10::irange(outputTensors.size())) {
-    check_gpu_single_tensor(outputTensors[r], true);
-    check_gpu_single_tensor(inputTensors[r], true);
+    check_gpu_single_tensor(outputTensors[r]);
+    check_gpu_single_tensor(inputTensors[r]);
     TORCH_CHECK(
         device == outputTensors[r].device() &&
             device == inputTensors[r].device(),