Fixes CUDA vs CPU consistency for index_put_ when accumulating (part 2) (#67189)
Summary:

Description: Follow-up PR to https://github.com/pytorch/pytorch/issues/66790, fixing the tests on functorch (https://github.com/pytorch/functorch/issues/195). In functorch, a null tensor is added to the list of indices for the batch dimension in C++, but I cannot find an equivalent of that in Python without using `torch.jit.script`. If a better solution can be suggested, I'd be happy to replace the current way of testing.

cc ngimel zou3519

Pull Request resolved: https://github.com/pytorch/pytorch/pull/67189
Reviewed By: suo
Differential Revision: D31966686
Pulled By: ngimel
fbshipit-source-id: a14b9e5d77d9f43cd728d474e2976d84a87a6ff4
Committed by: Facebook GitHub Bot
Parent commit: 3f048c637f
Commit: a2ab06514b
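The workaround the summary refers to appears in the new test_index_put_accumulate_with_optional_tensors below: scripting a helper with `torch.jit.script` is, per the summary, the only way found from Python to get a `None` entry in the indices list through to C++ as a null `c10::optional<Tensor>`, mimicking what functorch does for the batch dimension. A minimal standalone sketch of the same trick (the helper name and values are illustrative, not from the commit):

```python
import torch

@torch.jit.script
def put_with_null_index(x, i, v):
    # The None entry reaches C++ as a null optional tensor standing in
    # for the un-indexed batch dimension, as functorch does internally.
    idx = [None, i]
    x.index_put_(idx, v, accumulate=True)
    return x

t = torch.arange(8, dtype=torch.float32).reshape(4, 2)
out = put_with_null_index(t, torch.tensor([1, 0]), torch.tensor(10.0))
print(out)  # 10.0 accumulated into columns 1 and 0 of every row
```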
```diff
@@ -209,6 +209,9 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<c10::optional<Ten
   if (indices.size() > (size_t)self.dim()) {
     TORCH_CHECK_INDEX(false, "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
   }
+  if (!self.is_contiguous()) {
+    self = self.contiguous();
+  }
   Tensor linearIndex, src, expandedValue = value;
   int64_t nElemBefore, strideBefore, sliceSize;
   std::vector<int64_t> inversePerm;
```
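This first hunk addresses the non-contiguous case: the sort-based accumulate path addresses elements by flattened (linear) offset, which lines up with the logical element order of a contiguous tensor only, so `self` is made contiguous before the linear indices are computed. A small illustration of the stride mismatch, using the same view shape as the new test below (a sketch, not code from the commit):

```python
import torch

# The sort-based kernel assumes linear offsets match logical order,
# which holds only for contiguous tensors.
t = torch.zeros(5, 2, 2)
v = t[:, 0, :]                    # the non-contiguous view from the new test
print(v.is_contiguous())          # False
print(v.stride())                 # (4, 1): element (i, j) sits at offset 4*i + j
print(v.contiguous().stride())    # (2, 1): linear offsets assume 2*i + j
```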
```diff
@@ -216,7 +219,15 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<c10::optional<Ten
   int64_t num_indices = linearIndex.numel();
 
   if (expandedValue.numel() < num_indices * nElemBefore * sliceSize) {
-    auto expanded_size = infer_size_dimvector(expandedValue.sizes(), linearIndex.sizes());
+    auto expanded_size = at::DimVector(expandedValue.sizes());
+    auto size1 = expandedValue.sizes();
+    auto size2 = linearIndex.sizes();
+    if (are_expandable(size1, size2)) {
+      expanded_size = infer_size_dimvector(size1, size2);
+    }
+    if (nElemBefore > 1) {
+      expanded_size.insert(expanded_size.begin(), nElemBefore);
+    }
     expandedValue = expandedValue.expand(expanded_size);
   }
   expandedValue = expandedValue.contiguous();
```
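The second hunk makes the expansion of `value` defensive: the old code unconditionally broadcast `value` against the shape of the linear index, which fails when a null (batch) index makes the two shapes non-expandable. Now `infer_size_dimvector` is only consulted when `are_expandable` says it is safe, and when `nElemBefore > 1` (there are un-indexed leading dimensions, e.g. from a leading null index) that count is prepended as an extra leading dimension. The shape logic can be mirrored in Python roughly as follows (a sketch only; the function name is illustrative and the real computation runs on DimVectors in C++):

```python
import torch

def expanded_value_shape(value, linear_index, n_elem_before):
    # Broadcast against the index shape only when actually expandable
    # (mirrors the are_expandable guard in the hunk above).
    try:
        expanded = list(torch.broadcast_shapes(value.shape, linear_index.shape))
    except RuntimeError:
        expanded = list(value.shape)
    # Un-indexed leading dimensions contribute an extra leading size.
    if n_elem_before > 1:
        expanded.insert(0, n_elem_before)
    return torch.Size(expanded)

value0d = torch.tensor(10.0)
linear_index = torch.tensor([2, 0])   # flattened indices, shape (2,)
print(expanded_value_shape(value0d, linear_index, n_elem_before=4))
# torch.Size([4, 2]): a scalar value is broadcast across every write
```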
```diff
@@ -277,8 +288,9 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<c10::optional<Ten
       C10_CUDA_KERNEL_LAUNCH_CHECK();
     });
 
-    if (permuted)
-      self.copy_(src_.permute(inversePerm));
+    if (permuted) {
+      self.copy_(src_.permute(inversePerm));
+    }
   }
 }
 
```
```diff
@@ -804,6 +804,50 @@ class TestIndexing(TestCase):
         out_cpu = t.index_put_(indices, values2d, accumulate=True)
         self.assertEqual(out_cuda.cpu(), out_cpu)
 
+    @onlyCUDA
+    def test_index_put_accumulate_non_contiguous(self, device):
+        t = torch.zeros((5, 2, 2))
+        t_dev = t.to(device)
+        t1 = t_dev[:, 0, :]
+        t2 = t[:, 0, :]
+        self.assertTrue(not t1.is_contiguous())
+        self.assertTrue(not t2.is_contiguous())
+
+        indices = [torch.tensor([0, 1]), ]
+        indices_dev = [i.to(device) for i in indices]
+        value = torch.randn(2, 2)
+        out_cuda = t1.index_put_(indices_dev, value.to(device), accumulate=True)
+        out_cpu = t2.index_put_(indices, value, accumulate=True)
+        self.assertEqual(out_cuda.cpu(), out_cpu)
+
+    @onlyCUDA
+    def test_index_put_accumulate_with_optional_tensors(self, device):
+        # TODO: replace with a better solution.
+        # Currently, here using torchscript to put None into indices.
+        # on C++ it gives indices as a list of 2 optional tensors: first is null and
+        # the second is a valid tensor.
+        @torch.jit.script
+        def func(x, i, v):
+            idx = [None, i]
+            x.index_put_(idx, v, accumulate=True)
+            return x
+
+        n = 4
+        t = torch.arange(n * 2, dtype=torch.float32).reshape(n, 2)
+        t_dev = t.to(device)
+        indices = torch.tensor([1, 0])
+        indices_dev = indices.to(device)
+        value0d = torch.tensor(10.0)
+        value1d = torch.tensor([1.0, 2.0])
+
+        out_cuda = func(t_dev, indices_dev, value0d.cuda())
+        out_cpu = func(t, indices, value0d)
+        self.assertEqual(out_cuda.cpu(), out_cpu)
+
+        out_cuda = func(t_dev, indices_dev, value1d.cuda())
+        out_cpu = func(t, indices, value1d)
+        self.assertEqual(out_cuda.cpu(), out_cpu)
+
     @onlyNativeDeviceTypes
     def test_index_put_accumulate_duplicate_indices(self, device):
         for i in range(1, 512):
```
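The pre-existing test_index_put_accumulate_duplicate_indices (shown truncated above) covers the remaining delicate case: with accumulate=True, repeated indices must sum their contributions rather than overwrite each other, which is where CPU and CUDA most easily drift apart. For reference, the semantics being checked (a standalone sketch, not code from the commit):

```python
import torch

# accumulate=True sums every contribution at a repeated index instead of
# letting the last write win -- the behavior CPU and CUDA must agree on.
t = torch.zeros(3)
idx = (torch.tensor([0, 0, 2]),)
t.index_put_(idx, torch.tensor([1.0, 2.0, 5.0]), accumulate=True)
print(t)  # tensor([3., 0., 5.])
```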