use more elements per thread for narrow dtypes (#139449)

Fix a performance issue for narrow dtypes by having each thread access more elements.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139449
Approved by: https://github.com/Chillee, https://github.com/eqy
This commit is contained in:
Natalia Gimelshein
2024-11-04 16:43:33 +00:00
committed by PyTorch MergeBot
parent 3ca794783f
commit d3fc13a9dd
5 changed files with 67 additions and 27 deletions

View File

@@ -1045,7 +1045,6 @@ class TestReductions(TestCase):
a[:, (shape[1] - 1) // 2:] = True
values, indices = a.mode(-1)
self.assertEqual(values, torch.ones(shape[0], dtype=torch.bool))
print(indices)
indexed = a.gather(1, indices.unsqueeze(1)).squeeze(1)
self.assertEqual(values, indexed)