Optimize aten.cat calls of a repeated element (#132081)

This was a particular problem for a model I saw which would have a large number of repeats, making compilation slow.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132081
Approved by: https://github.com/shunting314
This commit is contained in:
eellison
2024-07-29 16:47:18 -07:00
committed by PyTorch MergeBot
parent f8e4060484
commit baa4c9ca46
3 changed files with 23 additions and 1 deletions

View File

@ -7,6 +7,7 @@ import unittest
from collections import defaultdict
from functools import partial
import torch._inductor.decomposition
import torch.autograd
from torch import Tensor
from torch._decomp import core_aten_decompositions, decomposition_table
@ -607,6 +608,16 @@ class TestDecomp(TestCase):
self.assertEqual(xs, xs_two)
def test_cat_single_input(self, device):
decomp_table = torch._inductor.decomposition.select_decomp_table()
cat_inductor = decomp_table[torch.ops.aten.cat.default]
inp = torch.rand([2048, 2048], device=device)
inps = [inp for _ in range(10)]
for dim in (-1, 0, 1):
self.assertEqual(torch.cat(inps, dim), cat_inductor(inps, dim))
def test_rrelu_with_noise(self, device):
# rrelu_with_noise behavior depends on a) whether elements in the input
# are <= 0, and b) whether we're in training mode. Cover all cases: