Fixes CUDA vs CPU consistency for index_put_ when accumulating (part 2) (#67189)

Summary:
Description:
- Follow-up PR to https://github.com/pytorch/pytorch/issues/66790, fixing the failing functorch tests tracked in https://github.com/pytorch/functorch/issues/195

In functorch, a null tensor is added to the list of indices for the batch dimension in C++, but I cannot find an equivalent way to do that in Python without using `torch.jit.script`. If a better solution can be suggested, I'd be happy to replace the current way of testing.
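For context, a minimal standalone sketch of that TorchScript workaround (the function name `put_with_none_index` is illustrative, and the snippet assumes a CUDA-enabled build):

```python
import torch

# Scripting the function lets the None entry reach C++ as a null
# optional<Tensor> in the indices list, i.e. the slot functorch fills
# with a null tensor for the batch dimension.
@torch.jit.script
def put_with_none_index(x, i, v):
    idx = [None, i]  # index dim 1 with `i`, leave dim 0 as a full slice
    x.index_put_(idx, v, accumulate=True)
    return x

t = torch.arange(8, dtype=torch.float32).reshape(4, 2)
out_cpu = put_with_none_index(t.clone(), torch.tensor([1, 0]), torch.tensor(10.0))
out_cuda = put_with_none_index(t.cuda(), torch.tensor([1, 0]).cuda(), torch.tensor(10.0).cuda())
assert torch.allclose(out_cpu, out_cuda.cpu())
```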

cc ngimel zou3519

Pull Request resolved: https://github.com/pytorch/pytorch/pull/67189

Reviewed By: suo

Differential Revision: D31966686

Pulled By: ngimel

fbshipit-source-id: a14b9e5d77d9f43cd728d474e2976d84a87a6ff4
Author: vfdev-5
Date: 2021-11-08 17:55:03 -08:00
Committed by: Facebook GitHub Bot
Parent: 3f048c637f
Commit: a2ab06514b

2 changed files with 59 additions and 3 deletions

aten/src/ATen/native/cuda/Indexing.cu

@@ -209,6 +209,9 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<c10::optional<Ten
   if (indices.size() > (size_t)self.dim()) {
     TORCH_CHECK_INDEX(false, "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
   }
+  if (!self.is_contiguous()) {
+    self = self.contiguous();
+  }
   Tensor linearIndex, src, expandedValue = value;
   int64_t nElemBefore, strideBefore, sliceSize;
   std::vector<int64_t> inversePerm;
@@ -216,7 +219,15 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<c10::optional<Ten
   int64_t num_indices = linearIndex.numel();
   if (expandedValue.numel() < num_indices * nElemBefore * sliceSize) {
-    auto expanded_size = infer_size_dimvector(expandedValue.sizes(), linearIndex.sizes());
+    auto expanded_size = at::DimVector(expandedValue.sizes());
+    auto size1 = expandedValue.sizes();
+    auto size2 = linearIndex.sizes();
+    if (are_expandable(size1, size2)) {
+      expanded_size = infer_size_dimvector(size1, size2);
+    }
+    if (nElemBefore > 1) {
+      expanded_size.insert(expanded_size.begin(), nElemBefore);
+    }
     expandedValue = expandedValue.expand(expanded_size);
   }
   expandedValue = expandedValue.contiguous();
@@ -277,8 +288,9 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<c10::optional<Ten
         C10_CUDA_KERNEL_LAUNCH_CHECK();
       });
-      if (permuted)
-        self.copy_(src_.permute(inversePerm));
+      if (permuted) {
+        self.copy_(src_.permute(inversePerm));
+      }
   }
 }

test/test_indexing.py

@@ -804,6 +804,50 @@ class TestIndexing(TestCase):
         out_cpu = t.index_put_(indices, values2d, accumulate=True)
         self.assertEqual(out_cuda.cpu(), out_cpu)
 
+    @onlyCUDA
+    def test_index_put_accumulate_non_contiguous(self, device):
+        t = torch.zeros((5, 2, 2))
+        t_dev = t.to(device)
+        t1 = t_dev[:, 0, :]
+        t2 = t[:, 0, :]
+        self.assertTrue(not t1.is_contiguous())
+        self.assertTrue(not t2.is_contiguous())
+
+        indices = [torch.tensor([0, 1]), ]
+        indices_dev = [i.to(device) for i in indices]
+        value = torch.randn(2, 2)
+        out_cuda = t1.index_put_(indices_dev, value.to(device), accumulate=True)
+        out_cpu = t2.index_put_(indices, value, accumulate=True)
+        self.assertEqual(out_cuda.cpu(), out_cpu)
+
+    @onlyCUDA
+    def test_index_put_accumulate_with_optional_tensors(self, device):
+        # TODO: replace with a better solution.
+        # Currently TorchScript is used to put None into the indices;
+        # in C++ this gives indices as a list of 2 optional tensors: the first is null
+        # and the second is a valid tensor.
+        @torch.jit.script
+        def func(x, i, v):
+            idx = [None, i]
+            x.index_put_(idx, v, accumulate=True)
+            return x
+
+        n = 4
+        t = torch.arange(n * 2, dtype=torch.float32).reshape(n, 2)
+        t_dev = t.to(device)
+        indices = torch.tensor([1, 0])
+        indices_dev = indices.to(device)
+        value0d = torch.tensor(10.0)
+        value1d = torch.tensor([1.0, 2.0])
+
+        out_cuda = func(t_dev, indices_dev, value0d.cuda())
+        out_cpu = func(t, indices, value0d)
+        self.assertEqual(out_cuda.cpu(), out_cpu)
+
+        out_cuda = func(t_dev, indices_dev, value1d.cuda())
+        out_cpu = func(t, indices, value1d)
+        self.assertEqual(out_cuda.cpu(), out_cpu)
+
     @onlyNativeDeviceTypes
     def test_index_put_accumulate_duplicate_indices(self, device):
         for i in range(1, 512):