#### Issue

* Currently, call ops do not handle IR with multiple outputs. If an op has multiple outputs, we add an implicit unpack that maps each output. For example:

  ```
  %5 : Tensor, %6 : Tensor = aten::max(%x.1, %3, %4), scope: export.test_converter.M:: # /data/users/jiashenc/pytorch/test/export/test_converter.py:774:20
  ```

* In some cases, `prim::If` sub-blocks do not return any outputs. For example:

  ```
  %9 : bool = aten::gt(%8, %3), scope: export.test_converter.M::/torch.nn.modules.pooling.AdaptiveMaxPool2d::pool # <string>:5:9
   = prim::If(%9), scope: export.test_converter.M::/torch.nn.modules.pooling.AdaptiveMaxPool2d::pool # <string>:5:2
    block0():
      -> ()
    block1():
       = prim::RaiseException(%5, %4), scope: export.test_converter.M::/torch.nn.modules.pooling.AdaptiveMaxPool2d::pool # <string>:5:2
      -> ()
  ```

#### Test Plan

We did an exhaustive search of all torch APIs that can return multiple outputs, sampled some of the common ones, and added new test cases based on those.

* `pytest test/export/test_converter.py -s -k test_ts2ep_multi_outputs_on_call_ops`

#### Appendix

* aten ops that return multiple outputs:
```
aten._batch_norm_impl_index
aten._batch_norm_no_update
aten._batch_norm_with_update
aten._batch_norm_with_update_functional
aten._cudnn_rnn
aten._efficient_attention_backward
aten._efficient_attention_forward
aten._embedding_bag
aten._embedding_bag_forward_only
aten._flash_attention_backward
aten._flash_attention_forward
aten._fused_adam
aten._fused_dropout
aten._fused_moving_avg_obs_fq_helper
aten._linalg_det
aten._linalg_eigh
aten._linalg_slogdet
aten._linalg_solve_ex
aten._linalg_svd
aten._native_batch_norm_legit
aten._native_batch_norm_legit_functional
aten._native_batch_norm_legit_no_training
aten._pack_padded_sequence
aten._prelu_kernel_backward
aten._scaled_dot_product_efficient_attention
aten._scaled_dot_product_efficient_attention_backward
aten._scaled_dot_product_flash_attention
aten._scaled_dot_product_flash_attention_backward
aten._scaled_dot_product_flash_attention_for_cpu
aten._scaled_dot_product_flash_attention_for_cpu_backward
aten._thnn_fused_lstm_cell
aten._thnn_fused_lstm_cell_backward_impl
aten._unique2
aten._weight_norm_interface
aten.adaptive_max_pool2d
aten.adaptive_max_pool3d
aten.aminmax
aten.batch_norm_backward
aten.convolution_backward
aten.cudnn_batch_norm
aten.cudnn_batch_norm_backward
aten.cummax
aten.cummin
aten.fractional_max_pool2d
aten.frexp
aten.grid_sampler_2d_backward
aten.grid_sampler_3d_backward
aten.gru
aten.linalg_cholesky_ex
aten.linalg_eig
aten.linalg_inv_ex
aten.linalg_ldl_factor_ex
aten.linalg_lu
aten.linalg_lu_factor_ex
aten.linalg_qr
aten.linear_backward
aten.log_sigmoid_forward
aten.lstm
aten.lu_unpack
aten.max
aten.max_pool2d_with_indices
aten.max_pool3d_with_indices
aten.median
aten.min
aten.miopen_batch_norm
aten.miopen_batch_norm_backward
aten.mkldnn_rnn_layer
aten.mkldnn_rnn_layer_backward
aten.mode
aten.multilabel_margin_loss_forward
aten.nanmedian
aten.native_batch_norm
aten.native_batch_norm_backward
aten.native_dropout
aten.native_group_norm
aten.native_group_norm_backward
aten.native_layer_norm
aten.native_layer_norm_backward
aten.nll_loss2d_forward
aten.nll_loss_forward
aten.quantized_gru
aten.quantized_lstm
aten.rnn_relu
aten.rnn_tanh
aten.sort
aten.std_mean
aten.topk
aten.triangular_solve
aten.unique_dim
aten.var_mean
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129294
Approved by: https://github.com/angelayi
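To make the two cases from the Issue section concrete, here is a minimal, torch-free Python sketch of how a converter might bind IR output names. Everything in it is hypothetical and for illustration only: `CallNode`, `convert_call_op`, `convert_if`, `toy_max`, and `raise_exception` are not part of the converter in this PR. It assumes the implicit unpack binds the i-th IR output (e.g. `%5`, `%6`) to `result[i]`, written with `operator.getitem` because that is how FX graphs commonly index tuple outputs, and that a `prim::If` whose sub-blocks end in `-> ()` simply binds nothing.

```python
import operator
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple


@dataclass
class CallNode:
    """Hypothetical, simplified stand-in for a TorchScript IR call node."""
    inputs: List[str]        # IR value names consumed, e.g. ["%x.1"]
    outputs: List[str]       # IR value names produced, e.g. ["%5", "%6"]
    fn: Callable[..., Any]   # Python callable standing in for the aten op


def convert_call_op(node: CallNode, env: Dict[str, Any]) -> None:
    """Evaluate the op and bind every IR output name in `env`."""
    result = node.fn(*(env[name] for name in node.inputs))
    if len(node.outputs) == 1:
        env[node.outputs[0]] = result
        return
    # Implicit unpack (assumed): the i-th IR output maps to result[i].
    for i, name in enumerate(node.outputs):
        env[name] = operator.getitem(result, i)


def convert_if(cond: bool,
               block0: Callable[[], Tuple[Any, ...]],
               block1: Callable[[], Tuple[Any, ...]],
               outputs: List[str],
               env: Dict[str, Any]) -> None:
    """Run block0 if `cond` is true, else block1, then bind the node's outputs.

    A prim::If whose sub-blocks end in `-> ()` has no outputs, so nothing is bound.
    """
    results = block0() if cond else block1()
    if len(results) != len(outputs):
        raise ValueError(f"expected {len(outputs)} results, got {len(results)}")
    for name, value in zip(outputs, results):
        env[name] = value


def toy_max(xs: List[float]) -> Tuple[float, int]:
    """Toy stand-in for aten::max over a flat list (dim/keepdim omitted): (value, index)."""
    idx = max(range(len(xs)), key=xs.__getitem__)
    return xs[idx], idx


def raise_exception() -> Tuple[Any, ...]:
    """Stand-in for a sub-block that only calls prim::RaiseException."""
    raise RuntimeError("input check failed")


if __name__ == "__main__":
    env: Dict[str, Any] = {"%x.1": [3.0, 7.0, 5.0]}
    convert_call_op(CallNode(["%x.1"], ["%5", "%6"], toy_max), env)
    print(env["%5"], env["%6"])  # 7.0 1

    # Mirrors the prim::If example: block0 returns (), block1 raises.
    convert_if(True, lambda: (), raise_exception, [], env)
    print(sorted(env))  # ['%5', '%6', '%x.1']
```

Binding outputs by name lets later nodes that refer to `%5` and `%6` directly keep working without knowing the op returned a tuple, and the zero-output `prim::If` case falls out naturally as an empty `zip`.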