Jiashen Cao 686b7f046a [Fix]: TSConverter handles call ops with multiple outputs (#129294)
#### Issue
* The current call-op handling does not support IR nodes with multiple outputs. If an op has multiple outputs, we add an implicit unpack to map each output. E.g.,
```
%5 : Tensor, %6 : Tensor = aten::max(%x.1, %3, %4), scope: export.test_converter.M:: # /data/users/jiashenc/pytorch/test/export/test_converter.py:774:20
```
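The implicit unpack can be sketched as follows. This is a hypothetical illustration, not the actual `TS2EPConverter` code: `bind_call_outputs`, the `env` dict, and the `%`-prefixed value names are stand-ins for the converter's internal name environment. When a node such as `%5, %6 = aten::max(...)` produces a tuple, each element is bound to its own output name.

```python
def bind_call_outputs(env, output_names, result):
    """Bind a converted call op's result(s) to the node's output value names.

    A single-output node maps its result directly; a multi-output node's
    tuple result is implicitly unpacked, one element per output name.
    """
    if len(output_names) == 1:
        env[output_names[0]] = result  # single output: direct mapping
    else:
        assert len(result) == len(output_names), "output arity mismatch"
        for name, value in zip(output_names, result):  # implicit unpack
            env[name] = value
    return env

# e.g. %5, %6 = aten::max(%x, %dim, %keepdim) yields a (values, indices) pair
env = bind_call_outputs({}, ["%5", "%6"], ("values", "indices"))
```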
* There are some cases where `prim::If` sub-blocks do not return any outputs. E.g.,
```
%9 : bool = aten::gt(%8, %3), scope: export.test_converter.M::/torch.nn.modules.pooling.AdaptiveMaxPool2d::pool # <string>:5:9
   = prim::If(%9), scope: export.test_converter.M::/torch.nn.modules.pooling.AdaptiveMaxPool2d::pool # <string>:5:2
    block0():
      -> ()
    block1():
       = prim::RaiseException(%5, %4), scope: export.test_converter.M::/torch.nn.modules.pooling.AdaptiveMaxPool2d::pool # <string>:5:2
      -> ()
```
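A zero-output `prim::If` like the one above (a guard that either does nothing or raises) still has to be converted for its side effects, but there are no result values to bind afterwards. A hypothetical sketch, again with made-up names rather than the converter's real API:

```python
def convert_if(env, condition, true_block, false_block, output_names=()):
    """Convert a prim::If node.

    output_names may be empty: side-effect-only Ifs (e.g. a RaiseException
    guard) execute a sub-block but bind nothing into the environment.
    """
    result = (true_block if condition else false_block)()
    if not output_names:  # zero-output sub-blocks: nothing to map back
        return env
    if len(output_names) == 1:
        env[output_names[0]] = result
    else:
        for name, value in zip(output_names, result):
            env[name] = value
    return env

# Zero-output case: the passing branch is a no-op, the failing branch raises.
def ok():
    return None

def fail():
    raise RuntimeError("guard failed")

env = convert_if({}, True, ok, fail)  # no exception, env left unchanged
```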

#### Test Plan
We did an exhaustive search of all torch APIs that can return multiple outputs, sampled some common ones, and added new test cases based on them.
* `pytest test/export/test_converter.py -s -k test_ts2ep_multi_outputs_on_call_ops`

#### Appendix
* aten ops that return multiple outputs.
```
aten._batch_norm_impl_index
aten._batch_norm_no_update
aten._batch_norm_with_update
aten._batch_norm_with_update_functional
aten._cudnn_rnn
aten._efficient_attention_backward
aten._efficient_attention_forward
aten._embedding_bag
aten._embedding_bag_forward_only
aten._flash_attention_backward
aten._flash_attention_forward
aten._fused_adam
aten._fused_dropout
aten._fused_moving_avg_obs_fq_helper
aten._linalg_det
aten._linalg_eigh
aten._linalg_slogdet
aten._linalg_solve_ex
aten._linalg_svd
aten._native_batch_norm_legit
aten._native_batch_norm_legit_functional
aten._native_batch_norm_legit_no_training
aten._pack_padded_sequence
aten._prelu_kernel_backward
aten._scaled_dot_product_efficient_attention
aten._scaled_dot_product_efficient_attention_backward
aten._scaled_dot_product_flash_attention
aten._scaled_dot_product_flash_attention_backward
aten._scaled_dot_product_flash_attention_for_cpu
aten._scaled_dot_product_flash_attention_for_cpu_backward
aten._thnn_fused_lstm_cell
aten._thnn_fused_lstm_cell_backward_impl
aten._unique2
aten._weight_norm_interface
aten.adaptive_max_pool2d
aten.adaptive_max_pool3d
aten.aminmax
aten.batch_norm_backward
aten.convolution_backward
aten.cudnn_batch_norm
aten.cudnn_batch_norm_backward
aten.cummax
aten.cummin
aten.fractional_max_pool2d
aten.frexp
aten.grid_sampler_2d_backward
aten.grid_sampler_3d_backward
aten.gru
aten.linalg_cholesky_ex
aten.linalg_eig
aten.linalg_inv_ex
aten.linalg_ldl_factor_ex
aten.linalg_lu
aten.linalg_lu_factor_ex
aten.linalg_qr
aten.linear_backward
aten.log_sigmoid_forward
aten.lstm
aten.lu_unpack
aten.max
aten.max_pool2d_with_indices
aten.max_pool3d_with_indices
aten.median
aten.min
aten.miopen_batch_norm
aten.miopen_batch_norm_backward
aten.mkldnn_rnn_layer
aten.mkldnn_rnn_layer_backward
aten.mode
aten.multilabel_margin_loss_forward
aten.nanmedian
aten.native_batch_norm
aten.native_batch_norm_backward
aten.native_dropout
aten.native_group_norm
aten.native_group_norm_backward
aten.native_layer_norm
aten.native_layer_norm_backward
aten.nll_loss2d_forward
aten.nll_loss_forward
aten.quantized_gru
aten.quantized_lstm
aten.rnn_relu
aten.rnn_tanh
aten.sort
aten.std_mean
aten.topk
aten.triangular_solve
aten.unique_dim
aten.var_mean
```
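For instance, `aten.max` from the list above is multi-output when called with a `dim` argument, which is the shape the converter's implicit unpack has to handle:

```python
import torch

x = torch.tensor([[1.0, 3.0], [2.0, 0.0]])
# torch.max with dim returns two outputs: reduced values and their indices.
values, indices = torch.max(x, dim=1)
```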
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129294
Approved by: https://github.com/angelayi
2024-07-18 21:55:18 +00:00