Allow int vals to go down the fastpath for _foreach_max (#127303)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127303
Approved by: https://github.com/albanD
ghstack dependencies: #127187
Author: Jane Xu
Date: 2024-05-29 08:18:19 -07:00
Committed by: PyTorch MergeBot
Parent: 601c5e085d
Commit: 05e99154ee
4 changed files with 10 additions and 19 deletions


@@ -1015,7 +1015,7 @@ class TestForeach(TestCase):
     def test_foreach_reduce_large_input(self, device, dtype, op):
         # test inputs larger than kChunkSize = 65536
         N = 65536 * 2
-        disable_fastpath = dtype in (torch.int8, torch.int16, torch.bool)
+        disable_fastpath = False
        kwargs = {}
         if op.name == "_foreach_norm":
             ord = 2
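
The test change above means the fastpath is now expected for every dtype, including the integer and bool types that previously forced the slowpath. The snippet below is a minimal usage sketch, not part of the diff: it calls torch._foreach_max on a list of integer tensors and checks the results against per-tensor torch.max, assuming a PyTorch build that includes this change and the _foreach_max op introduced in the dependent PR #127187.

import torch

# Minimal sketch (assumes torch._foreach_max is available in this build).
device = "cuda" if torch.cuda.is_available() else "cpu"
tensors = [
    torch.randint(-100, 100, (65536 * 2,), dtype=torch.int16, device=device)
    for _ in range(4)
]

result = torch._foreach_max(tensors)   # one 0-dim tensor per input tensor
expected = [t.max() for t in tensors]  # per-tensor reference values

for r, e in zip(result, expected):
    assert torch.equal(r, e)
print("torch._foreach_max matches per-tensor torch.max")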