[ROCm] Improve backwards indexing when stride is not one (#147630)

Improve backwards indexing when stride is not one.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147630
Approved by: https://github.com/jeffdaily
This commit is contained in:
Doru Bercea
2025-03-11 19:02:44 +00:00
committed by PyTorch MergeBot
parent daff65d671
commit a1cb67b69e
2 changed files with 230 additions and 235 deletions

View File

@ -992,6 +992,7 @@ class TestIndexing(TestCase):
num_indices = 401988
max_index_range = 2000
target_index_range = [16, 256, 2000]
# BFloat16
for generated_index_range in target_index_range:
# create CPU tensors
a_tensor_size = (max_index_range, 256)
@ -1010,6 +1011,27 @@ class TestIndexing(TestCase):
a_dev.index_put_(indices=[b_dev], values=c_dev, accumulate=True)
self.assertEqual(a_dev.cpu(), a)
# Float32
for generated_index_range in target_index_range:
# create CPU tensors
a_tensor_size = (max_index_range, 256)
a = torch.randn(a_tensor_size, dtype=torch.float32)
b = generate_indices(
num_indices=num_indices, index_range=generated_index_range
)
c_tensor_size = (num_indices, 256)
c = torch.randn(c_tensor_size, dtype=torch.float32)
# create GPU copies
a_dev = a.to(device)
b_dev = b.to(device)
c_dev = c.to(device)
# run
torch.use_deterministic_algorithms(True)
a.index_put_(indices=[b], values=c, accumulate=True)
torch.use_deterministic_algorithms(False)
a_dev.index_put_(indices=[b_dev], values=c_dev, accumulate=True)
self.assertEqual(a_dev.cpu(), a)
@onlyCUDA
def test_index_put_accumulate_non_contiguous(self, device):
t = torch.zeros((5, 2, 2))