mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ONNX] Fix upsample_bilinear2d decomp skip with output shape (#118823)
The previous output size missed the first two dimensions. Pull Request resolved: https://github.com/pytorch/pytorch/pull/118823 Approved by: https://github.com/titaiwangms
This commit is contained in:
committed by
PyTorch MergeBot
parent
6692f2c91e
commit
d4a94ad041
@ -31,6 +31,14 @@ class TestDynamoExportDecompSkip(pytorch_test_common.ExportTestCase):
|
||||
# If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph.
|
||||
assert_op_in_onnx_model(onnx_program.model_proto, "Resize")
|
||||
|
||||
def test_upsample_bilinear2d_output_size(self):
|
||||
def func(x: torch.Tensor):
|
||||
return torch.nn.functional.interpolate(x, size=(4, 4), mode="bilinear")
|
||||
|
||||
onnx_program = torch.onnx.dynamo_export(func, torch.randn(1, 1, 2, 2))
|
||||
# If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph.
|
||||
assert_op_in_onnx_model(onnx_program.model_proto, "Resize")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
common_utils.run_tests()
|
||||
|
@ -28,7 +28,7 @@ _NEW_OP_NAMESPACE: str = "onnx_export"
|
||||
"""The namespace for the custom operator."""
|
||||
|
||||
|
||||
class DecompSkip:
|
||||
class DecompSkip(abc.ABC):
|
||||
op_callable: Callable
|
||||
"""The original operator callable to skip decomposition."""
|
||||
onnxscript_function: Callable
|
||||
@ -112,7 +112,11 @@ class UpsampleBilinear2DDecompSkip(DecompSkip):
|
||||
osize = decompositions.upsample_compute_output_size(
|
||||
input.size(), output_size, scale_factors
|
||||
)
|
||||
return torch.empty(osize, dtype=input.dtype, device=input.device)
|
||||
return torch.empty(
|
||||
(input.size(0), input.size(1), *osize),
|
||||
dtype=input.dtype,
|
||||
device=input.device,
|
||||
)
|
||||
|
||||
|
||||
_DEFAULT_SKIP_LIST = [
|
||||
|
Reference in New Issue
Block a user