[ONNX][dynamo_export] Skip instance_norm decomp for export (#120866)

Without this change, instance_norm is decomposed into batch_norm with training set to True,
and the downstream exporter has no way to tell that training is not actually needed.
ONNX does define an InstanceNormalization operator, but because of the decomposition the
model is unnecessarily exported as BatchNormalization plus glue code instead.
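
As a rough illustration of the problem (a minimal sketch; the module and input shape are arbitrary, and `torch.export` is used here only to surface the default decomposition behavior, as an assumption about the downstream pipeline rather than part of this change):

```python
import torch


class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.instance_norm(x)


# aten::instance_norm has a composite implementation written in terms of
# batch_norm, so once the default decompositions run, only batch-norm style
# ops (with training semantics) remain visible to a downstream exporter.
ep = torch.export.export(M(), (torch.randn(1, 1, 2, 2),))
ep = ep.run_decompositions()
print(ep.graph_module.code)  # expect batch-norm variants here, no instance_norm
```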

Depends on https://github.com/microsoft/onnxscript/pull/1284
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120866
Approved by: https://github.com/thiagocrepaldi, https://github.com/titaiwangms
Author: BowenBao
Date: 2024-03-01 10:45:15 -08:00
Committed by: PyTorch MergeBot
Parent: 581fe26792
Commit: d8395830ea
4 changed files with 73 additions and 12 deletions


@@ -32,8 +32,8 @@ pip_install coloredlogs packaging
 pip_install onnxruntime==1.17.0
 pip_install onnx==1.15.0
-# pip_install "onnxscript@git+https://github.com/microsoft/onnxscript@1d6362db06706c13447e590ecf5ac3238efc1880" --no-deps
-pip_install onnxscript==0.1.0.dev20240216 --no-deps
+# pip_install "onnxscript@git+https://github.com/microsoft/onnxscript@3e869ef8ccf19b5ebd21c10d3e9c267c9a9fa729" --no-deps
+pip_install onnxscript==0.1.0.dev20240301 --no-deps
 # Cache the transformers model to be used later by ONNX tests. We need to run the transformers
 # package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/


@@ -893,16 +893,6 @@ EXPECTED_SKIPS_OR_FAILS: Tuple[onnx_test_common.DecorateMeta, ...] = (
         dtypes=(torch.float16,),
         reason=onnx_test_common.reason_onnx_runtime_does_not_support("GroupNormalization", "float16"),
     ),
-    xfail(
-        "nn.functional.instance_norm",
-        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
-        reason="fixme: Assertion error: result mismatch",
-    ),
-    xfail(
-        "nn.functional.instance_norm",
-        model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
-        reason="Functionalize pass failed",
-    ),
     xfail(
         "nn.functional.local_response_norm",
         dtypes=(torch.int64,),
@@ -1548,6 +1538,13 @@ SKIP_XFAIL_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
             "Reshape", "empty tensor"
         ),
     ),
+    xfail(
+        "nn.functional.instance_norm",
+        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
+        matcher=lambda sample: sample.kwargs.get("running_mean") is not None
+        or sample.input.dtype in (torch.float16,),
+        reason="fixme: KeyError: 'self___kwargs__running_mean'",
+    ),
     xfail(
         "nn.functional.max_pool3d",
         matcher=lambda sample: sample.kwargs.get("ceil_mode") is True
@@ -1962,6 +1959,7 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
         "nn.functional.hardsigmoid": [1e-3, 5e-3],
         "nn.functional.hardswish": [1e-3, 5e-3],
         "nn.functional.hinge_embedding_loss": [4e-1, 3e-3],
+        "nn.functional.instance_norm": [1e-2, 1e-3],
         "nn.functional.interpolate": [1e-2, 1e-3],
         "nn.functional.kl_div": [2e-3, 2e-4],
         "nn.functional.multilabel_soft_margin_loss": [4e-2, 5e-3],


@@ -39,6 +39,15 @@ class TestDynamoExportDecompSkip(pytorch_test_common.ExportTestCase):
         # If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph.
         assert_op_in_onnx_model(onnx_program.model_proto, "Resize")
 
+    def test_instance_norm(self):
+        def func(x: torch.Tensor):
+            return torch.nn.functional.instance_norm(x)
+
+        onnx_program = torch.onnx.dynamo_export(func, torch.randn(1, 1, 2, 2))
+        # If decomposition is skipped, the model will contain an InstanceNormalization op
+        # instead of BatchNormalization op w/ training=True.
+        assert_op_in_onnx_model(onnx_program.model_proto, "InstanceNormalization")
+
 
 if __name__ == "__main__":
     common_utils.run_tests()
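
Outside the test harness, the same check can be reproduced by scanning the exported proto directly. A minimal sketch (the `nn.InstanceNorm2d` module and shapes are illustrative; ONNX local functions are scanned too, since torchlib may wrap the op in a function):

```python
import torch

model = torch.nn.InstanceNorm2d(3)
x = torch.randn(2, 3, 8, 8)

onnx_program = torch.onnx.dynamo_export(model, x)
proto = onnx_program.model_proto

# Collect op types from the top-level graph and from any ONNX local functions.
op_types = {node.op_type for node in proto.graph.node}
for function in proto.functions:
    op_types.update(node.op_type for node in function.node)

print(sorted(op_types))
# With the decomposition skipped, InstanceNormalization should show up instead
# of a BatchNormalization-with-training subgraph.
assert "InstanceNormalization" in op_types
```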


@@ -18,6 +18,7 @@ import contextlib
 from typing import Callable, Sequence, Type
 from onnxscript.function_libs.torch_lib.ops import (  # type: ignore[import-not-found]
+    core as torchlib_core,
     nn as torchlib_nn,
 )
@@ -119,8 +120,61 @@ class UpsampleBilinear2DDecompSkip(DecompSkip):
         )
 
 
+class InstanceNormDecompSkip(DecompSkip):
+    op_callable = torch.instance_norm  # type: ignore[attr-defined]
+    onnxscript_function = torchlib_core.aten_instance_norm  # type: ignore[attr-defined]
+    new_op_name = "instance_norm"
+    new_op_schema = (
+        "(Tensor input, Tensor? weight, Tensor? bias, "
+        "Tensor? running_mean, Tensor? running_var, "
+        "bool use_input_stats, float momentum, float eps, "
+        "bool cudnn_enabled) -> Tensor"
+    )
+
+    @classmethod
+    def register(cls, export_options: torch.onnx.ExportOptions):
+        if not hasattr(torch.ops, _NEW_OP_NAMESPACE) or not hasattr(
+            torch.ops.onnx_export, cls.new_op_name
+        ):
+            cls.register_custom_op()
+
+        torch.instance_norm = torch.ops.onnx_export.instance_norm  # type: ignore[attr-defined]
+        if export_options.onnx_registry is None:
+            export_options.onnx_registry = torch.onnx.OnnxRegistry()
+        registry = export_options.onnx_registry
+        registry.register_op(
+            function=cls.onnxscript_function,
+            namespace=_NEW_OP_NAMESPACE,
+            op_name=cls.new_op_name,
+        )
+
+    @classmethod
+    def unregister(cls):
+        torch.instance_norm = cls.op_callable  # type: ignore[attr-defined]
+
+    @classmethod
+    def abstract(
+        cls,
+        input,
+        weight,
+        bias,
+        running_mean,
+        running_var,
+        use_input_stats: bool,
+        momentum: float,
+        eps: float,
+        cudnn_enabled: bool,
+    ):
+        return torch.empty(
+            input.size(),
+            dtype=input.dtype,
+            device=input.device,
+        )
+
+
 _DEFAULT_SKIP_LIST = [
     UpsampleBilinear2DDecompSkip,
+    InstanceNormDecompSkip,
 ]
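
The `register_custom_op()` call used by `register()` above lives in the `DecompSkip` base class and is not part of this diff. A hedged sketch of how the `onnx_export::instance_norm` op could be defined with `torch.library`, using only the schema and `abstract()` shown here (the actual base-class implementation may differ):

```python
import torch

_NEW_OP_NAMESPACE = "onnx_export"
qualname = f"{_NEW_OP_NAMESPACE}::instance_norm"

# Declare the custom op with the same schema as InstanceNormDecompSkip.new_op_schema.
torch.library.define(
    qualname,
    "(Tensor input, Tensor? weight, Tensor? bias, "
    "Tensor? running_mean, Tensor? running_var, "
    "bool use_input_stats, float momentum, float eps, "
    "bool cudnn_enabled) -> Tensor",
)


# Eager implementation: forward to the original ATen op, so behavior outside of
# export is unchanged even while torch.instance_norm is monkey-patched.
@torch.library.impl(qualname, "CompositeExplicitAutograd")
def _instance_norm_impl(
    input, weight, bias, running_mean, running_var,
    use_input_stats, momentum, eps, cudnn_enabled,
):
    return torch.ops.aten.instance_norm(
        input, weight, bias, running_mean, running_var,
        use_input_stats, momentum, eps, cudnn_enabled,
    )


# FakeTensor/meta implementation, mirroring InstanceNormDecompSkip.abstract().
@torch.library.impl_abstract(qualname)
def _instance_norm_abstract(
    input, weight, bias, running_mean, running_var,
    use_input_stats, momentum, eps, cudnn_enabled,
):
    return torch.empty(input.size(), dtype=input.dtype, device=input.device)
```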