mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Add shape function for stack op (#92205)
As @ramiro050 requested in https://github.com/llvm/torch-mlir/pull/1747, this PR moved the shape code for stack op from torch-mlir to pytorch upstream. Pull Request resolved: https://github.com/pytorch/pytorch/pull/92205 Approved by: https://github.com/eellison
This commit is contained in:
committed by
PyTorch MergeBot
parent
5e73cc9310
commit
3d5eba811a
@ -12,7 +12,7 @@ number = Union[int, float]
|
||||
# After regenerating files, compile PyTorch.
|
||||
# Then run: ./build/bin/test_jit --gtest_filter=TestShapeGraphLinting.Basic
|
||||
# If you have enabled opinfo testing for the op, also run:
|
||||
# python test/test_ops_jit.py TestJitCPU::test_variant_consistency_jit_[FAILING_OP]_cpu_float32
|
||||
# python test/test_ops_jit.py TestJitCPU.test_variant_consistency_jit_[FAILING_OP]_cpu_float32
|
||||
# to reproduce errors from opinfo tests.
|
||||
|
||||
# Example PR: https://github.com/pytorch/pytorch/pull/80860/files
|
||||
@ -545,6 +545,14 @@ def cat(tensors: List[List[int]], dim: int):
|
||||
return result_size
|
||||
|
||||
|
||||
def stack(tensors: List[List[int]], dim: int):
|
||||
unsqueezed_tensors: List[List[int]] = []
|
||||
for tensor in tensors:
|
||||
unsqueezed = unsqueeze(tensor, dim)
|
||||
unsqueezed_tensors.append(unsqueezed)
|
||||
return cat(unsqueezed_tensors, dim)
|
||||
|
||||
|
||||
def select(self: List[int], dim: int, index: int):
|
||||
ndim = len(self)
|
||||
assert ndim != 0
|
||||
@ -1100,6 +1108,7 @@ add_shape_compute_mapping("aten::convolution(Tensor input, Tensor weight, Tensor
|
||||
add_shape_compute_mapping("aten::conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int groups=1, int[2] dilation=1) -> Tensor", conv_transpose2d_input)
|
||||
add_shape_compute_mapping("aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)", flatten)
|
||||
add_shape_compute_mapping("aten::cat(Tensor[] tensors, int dim=0) -> Tensor", cat)
|
||||
add_shape_compute_mapping("aten::stack(Tensor[] tensors, int dim=0) -> Tensor", stack)
|
||||
add_shape_compute_mapping("aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)", permute)
|
||||
add_shape_compute_mapping("aten::movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)", movedim)
|
||||
add_shape_compute_mapping("aten::view(Tensor(a) self, int[] size) -> Tensor(a)", view)
|
||||
|
Reference in New Issue
Block a user