Optimize aten.cat calls of a repeated element (#132081)
This was a particular problem for a model I saw that had a large number of repeats, which made compilation slow.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132081
Approved by: https://github.com/shunting314
commit baa4c9ca46 (parent f8e4060484), committed by PyTorch MergeBot
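The optimization hinges on a simple equivalence: concatenating n copies of the same tensor along a dimension is the same as inserting a size-n broadcast dimension there and folding it back into that dimension. The sketch below illustrates that equivalence only; the helper cat_repeated is hypothetical and is not the actual inductor decomposition added by this commit. It checks the rewrite against eager torch.cat.

import torch

def cat_repeated(x: torch.Tensor, n: int, dim: int = 0) -> torch.Tensor:
    # Hypothetical helper for illustration: cat of n copies of x along `dim`
    # equals inserting a size-n broadcast dimension and flattening it back.
    dim = dim % x.dim()  # normalize negative dims
    expanded = x.unsqueeze(dim).expand(*x.shape[:dim], n, *x.shape[dim:])
    out_shape = list(x.shape)
    out_shape[dim] *= n
    return expanded.reshape(out_shape)

x = torch.rand(2048, 2048)
for dim in (-1, 0, 1):
    assert torch.equal(torch.cat([x] * 10, dim), cat_repeated(x, 10, dim))

A rewrite along these lines emits a single broadcast-style op instead of an n-way concatenation, which is why it helps when the repeat count is large.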
@@ -7,6 +7,7 @@ import unittest
 from collections import defaultdict
 from functools import partial
 
+import torch._inductor.decomposition
 import torch.autograd
 from torch import Tensor
 from torch._decomp import core_aten_decompositions, decomposition_table
@@ -607,6 +608,16 @@ class TestDecomp(TestCase):
 
         self.assertEqual(xs, xs_two)
 
+    def test_cat_single_input(self, device):
+        decomp_table = torch._inductor.decomposition.select_decomp_table()
+        cat_inductor = decomp_table[torch.ops.aten.cat.default]
+
+        inp = torch.rand([2048, 2048], device=device)
+        inps = [inp for _ in range(10)]
+
+        for dim in (-1, 0, 1):
+            self.assertEqual(torch.cat(inps, dim), cat_inductor(inps, dim))
+
     def test_rrelu_with_noise(self, device):
         # rrelu_with_noise behavior depends on a) whether elements in the input
         # are <= 0, and b) whether we're in training mode. Cover all cases: