MPS: TopK raises an error if k > 16 (#79677)
* Error out in TopK when k > 16.
* Add a test case too.

Fixes #78915
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79677
Approved by: https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: 270c518be0
Commit: 355a1c8c3f
@@ -300,6 +300,8 @@ TORCH_IMPL_FUNC(topk_out_mps)
       k >= 0 && k <= (self.dim() > 0 ? self.size(dim) : 1),
       "selected index k out of range");
 
+  TORCH_CHECK( k <= 16 , "Currently topk on mps works only for k<=16 ");
+
   if (self.dim() == 0 && self.numel() == 1)
   {
     values.copy_(self);
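For context, here is a minimal user-side sketch of what the new check changes (assuming a PyTorch build with the MPS backend available; the snippet is illustrative and not part of the commit):

import torch

# Illustrative only: shows the user-facing effect of the new TORCH_CHECK.
# Requires a macOS build of PyTorch where the MPS backend is available.
if torch.backends.mps.is_available():
    x = torch.arange(30, dtype=torch.float32, device="mps")

    # k <= 16 is supported and should match the CPU result.
    values, indices = x.topk(16)

    # k > 16 now fails fast with a RuntimeError instead of reaching the
    # kernel path that motivated issue #78915.
    try:
        x.topk(30)
    except RuntimeError as err:
        print("expected error:", err)
else:
    print("MPS backend not available; nothing to demonstrate.")

The regression test added below exercises exactly this behavior.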
@@ -2801,6 +2801,16 @@ class TestNLLLoss(TestCase):
 
         helper(3, 3)
 
+    def test_assert_topk(self):
+        # here the k > 16 raises an error as expected
+        with self.assertRaisesRegex(RuntimeError, "Currently topk on mps works only for k<=16"):
+            xs = torch.arange(30).to('mps')
+            xs.topk(30)
+        # for k <= 16 it works fine
+        ys_cpu = torch.arange(30)
+        ys_mps = ys_cpu.to('mps')
+        self.assertEqual(ys_cpu.topk(16), ys_mps.topk(16))
+
     def test_topk(self):
         def helper(shape):
             cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)