MPS: TopK raises an error if k > 16 (#79677)

* Error out in TopK when k>16.
* Add a test case too.

Fixes #78915

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79677
Approved by: https://github.com/albanD
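
As a usage note (not part of the commit): after this change, calling topk with k > 16 on an MPS tensor raises a RuntimeError, so code that needs a larger k can catch the error and fall back to the CPU. A minimal Python sketch; the topk_mps_safe helper name is hypothetical:

    import torch

    def topk_mps_safe(x, k):
        # Sketch only: try topk on the tensor's own device; if the MPS backend
        # rejects k > 16 (the check added in this commit), redo it on the CPU
        # and move the results back to the original device.
        try:
            return x.topk(k)
        except RuntimeError:
            values, indices = x.cpu().topk(k)
            return values.to(x.device), indices.to(x.device)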
Author: Kulin Seth
Date: 2022-06-16 16:06:45 +00:00
Committer: PyTorch MergeBot
Parent: 270c518be0
Commit: 355a1c8c3f
2 changed files with 12 additions and 0 deletions


@@ -300,6 +300,8 @@ TORCH_IMPL_FUNC(topk_out_mps)
       k >= 0 && k <= (self.dim() > 0 ? self.size(dim) : 1),
       "selected index k out of range");
+  TORCH_CHECK( k <= 16 , "Currently topk on mps works only for k<=16 ");
   if (self.dim() == 0 && self.numel() == 1)
   {
     values.copy_(self);

@@ -2801,6 +2801,16 @@ class TestNLLLoss(TestCase):
         helper(3, 3)
 
+    def test_assert_topk(self):
+        # here the k > 16 raises an error as expected
+        with self.assertRaisesRegex(RuntimeError, "Currently topk on mps works only for k<=16"):
+            xs = torch.arange(30).to('mps')
+            xs.topk(30)
+        # for k <= 16 it works fine
+        ys_cpu = torch.arange(30)
+        ys_mps = ys_cpu.to('mps')
+        self.assertEqual(ys_cpu.topk(16), ys_mps.topk(16))
+
     def test_topk(self):
         def helper(shape):
             cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)