[ONNX] Fix upsample_bilinear2d decomp skip with output shape (#118823)

The previous output size missed the first two dimensions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118823
Approved by: https://github.com/titaiwangms
This commit is contained in:
BowenBao
2024-02-01 22:04:35 +00:00
committed by PyTorch MergeBot
parent 6692f2c91e
commit d4a94ad041
2 changed files with 14 additions and 2 deletions

View File

@ -31,6 +31,14 @@ class TestDynamoExportDecompSkip(pytorch_test_common.ExportTestCase):
# If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph.
assert_op_in_onnx_model(onnx_program.model_proto, "Resize")
def test_upsample_bilinear2d_output_size(self):
def func(x: torch.Tensor):
return torch.nn.functional.interpolate(x, size=(4, 4), mode="bilinear")
onnx_program = torch.onnx.dynamo_export(func, torch.randn(1, 1, 2, 2))
# If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph.
assert_op_in_onnx_model(onnx_program.model_proto, "Resize")
if __name__ == "__main__":
common_utils.run_tests()

View File

@ -28,7 +28,7 @@ _NEW_OP_NAMESPACE: str = "onnx_export"
"""The namespace for the custom operator."""
class DecompSkip:
class DecompSkip(abc.ABC):
op_callable: Callable
"""The original operator callable to skip decomposition."""
onnxscript_function: Callable
@ -112,7 +112,11 @@ class UpsampleBilinear2DDecompSkip(DecompSkip):
osize = decompositions.upsample_compute_output_size(
input.size(), output_size, scale_factors
)
return torch.empty(osize, dtype=input.dtype, device=input.device)
return torch.empty(
(input.size(0), input.size(1), *osize),
dtype=input.dtype,
device=input.device,
)
_DEFAULT_SKIP_LIST = [