Add shape function for stack op (#92205)

As @ramiro050 requested in https://github.com/llvm/torch-mlir/pull/1747, this PR moves the shape function for the stack op from torch-mlir to upstream PyTorch.
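For reference, the rule mirrors the eager semantics of aten::stack: each input shape gets a size-1 dimension inserted at dim, and the results are concatenated along that dimension. A minimal standalone sketch of that rule (plain Python; unsqueeze_shape and stack_shape are illustrative names, not the registered shape functions):

from typing import List

def unsqueeze_shape(shape: List[int], dim: int) -> List[int]:
    # Normalize a possibly negative dim against the output rank,
    # then insert a size-1 dimension there (mirrors aten::unsqueeze).
    dim = dim if dim >= 0 else dim + len(shape) + 1
    return shape[:dim] + [1] + shape[dim:]

def stack_shape(tensors: List[List[int]], dim: int) -> List[int]:
    # stack is "unsqueeze each input at dim, then cat along dim";
    # since stack requires equal input shapes, the cat simply sets the
    # new dimension to the number of inputs.
    out = unsqueeze_shape(tensors[0], dim)
    out[dim if dim >= 0 else dim + len(out)] = len(tensors)
    return out

# Stacking three [2, 3] shapes along dim=0 gives [3, 2, 3];
# along dim=-1 the new axis goes last, giving [2, 3, 3].
assert stack_shape([[2, 3]] * 3, dim=0) == [3, 2, 3]
assert stack_shape([[2, 3]] * 3, dim=-1) == [2, 3, 3]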

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92205
Approved by: https://github.com/eellison
lijiahao
2023-03-07 20:45:56 +00:00
committed by PyTorch MergeBot
parent 5e73cc9310
commit 3d5eba811a
2 changed files with 159 additions and 1 deletion


@@ -12,7 +12,7 @@ number = Union[int, float]
# After regenerating files, compile PyTorch.
# Then run: ./build/bin/test_jit --gtest_filter=TestShapeGraphLinting.Basic
# If you have enabled opinfo testing for the op, also run:
# python test/test_ops_jit.py TestJitCPU::test_variant_consistency_jit_[FAILING_OP]_cpu_float32
# python test/test_ops_jit.py TestJitCPU.test_variant_consistency_jit_[FAILING_OP]_cpu_float32
# to reproduce errors from opinfo tests.
# Example PR: https://github.com/pytorch/pytorch/pull/80860/files
@@ -545,6 +545,14 @@ def cat(tensors: List[List[int]], dim: int):
    return result_size

def stack(tensors: List[List[int]], dim: int):
    unsqueezed_tensors: List[List[int]] = []
    for tensor in tensors:
        unsqueezed = unsqueeze(tensor, dim)
        unsqueezed_tensors.append(unsqueezed)
    return cat(unsqueezed_tensors, dim)

def select(self: List[int], dim: int, index: int):
    ndim = len(self)
    assert ndim != 0
@@ -1100,6 +1108,7 @@ add_shape_compute_mapping("aten::convolution(Tensor input, Tensor weight, Tensor
add_shape_compute_mapping("aten::conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int groups=1, int[2] dilation=1) -> Tensor", conv_transpose2d_input)
add_shape_compute_mapping("aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)", flatten)
add_shape_compute_mapping("aten::cat(Tensor[] tensors, int dim=0) -> Tensor", cat)
add_shape_compute_mapping("aten::stack(Tensor[] tensors, int dim=0) -> Tensor", stack)
add_shape_compute_mapping("aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)", permute)
add_shape_compute_mapping("aten::movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)", movedim)
add_shape_compute_mapping("aten::view(Tensor(a) self, int[] size) -> Tensor(a)", view)