mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ONNX][dynamo_export] Skip instance_norm decomp for export (#120866)
Otherwise, instance_norm is decomposed into batch_norm with training set to True. Downstream exporter has no way to figure out that training is actually not needed. On the other hand, ONNX does have InstanceNormalization operator defined, however due to decomp, it unnecessarily exports as batch norm and glue code. Depends on https://github.com/microsoft/onnxscript/pull/1284 Pull Request resolved: https://github.com/pytorch/pytorch/pull/120866 Approved by: https://github.com/thiagocrepaldi, https://github.com/titaiwangms
This commit is contained in:
committed by
PyTorch MergeBot
parent
581fe26792
commit
d8395830ea
@ -32,8 +32,8 @@ pip_install coloredlogs packaging
|
||||
|
||||
pip_install onnxruntime==1.17.0
|
||||
pip_install onnx==1.15.0
|
||||
# pip_install "onnxscript@git+https://github.com/microsoft/onnxscript@1d6362db06706c13447e590ecf5ac3238efc1880" --no-deps
|
||||
pip_install onnxscript==0.1.0.dev20240216 --no-deps
|
||||
# pip_install "onnxscript@git+https://github.com/microsoft/onnxscript@3e869ef8ccf19b5ebd21c10d3e9c267c9a9fa729" --no-deps
|
||||
pip_install onnxscript==0.1.0.dev20240301 --no-deps
|
||||
|
||||
# Cache the transformers model to be used later by ONNX tests. We need to run the transformers
|
||||
# package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/
|
||||
|
@ -893,16 +893,6 @@ EXPECTED_SKIPS_OR_FAILS: Tuple[onnx_test_common.DecorateMeta, ...] = (
|
||||
dtypes=(torch.float16,),
|
||||
reason=onnx_test_common.reason_onnx_runtime_does_not_support("GroupNormalization", "float16"),
|
||||
),
|
||||
xfail(
|
||||
"nn.functional.instance_norm",
|
||||
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
reason="fixme: Assertion error: result mismatch",
|
||||
),
|
||||
xfail(
|
||||
"nn.functional.instance_norm",
|
||||
model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
|
||||
reason="Functionalize pass failed",
|
||||
),
|
||||
xfail(
|
||||
"nn.functional.local_response_norm",
|
||||
dtypes=(torch.int64,),
|
||||
@ -1548,6 +1538,13 @@ SKIP_XFAIL_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
|
||||
"Reshape", "empty tensor"
|
||||
),
|
||||
),
|
||||
xfail(
|
||||
"nn.functional.instance_norm",
|
||||
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
matcher=lambda sample: sample.kwargs.get("running_mean") is not None
|
||||
or sample.input.dtype in (torch.float16,),
|
||||
reason="fixme: KeyError: 'self___kwargs__running_mean'",
|
||||
),
|
||||
xfail(
|
||||
"nn.functional.max_pool3d",
|
||||
matcher=lambda sample: sample.kwargs.get("ceil_mode") is True
|
||||
@ -1962,6 +1959,7 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
|
||||
"nn.functional.hardsigmoid": [1e-3, 5e-3],
|
||||
"nn.functional.hardswish": [1e-3, 5e-3],
|
||||
"nn.functional.hinge_embedding_loss": [4e-1, 3e-3],
|
||||
"nn.functional.instance_norm": [1e-2, 1e-3],
|
||||
"nn.functional.interpolate": [1e-2, 1e-3],
|
||||
"nn.functional.kl_div": [2e-3, 2e-4],
|
||||
"nn.functional.multilabel_soft_margin_loss": [4e-2, 5e-3],
|
||||
|
@ -39,6 +39,15 @@ class TestDynamoExportDecompSkip(pytorch_test_common.ExportTestCase):
|
||||
# If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph.
|
||||
assert_op_in_onnx_model(onnx_program.model_proto, "Resize")
|
||||
|
||||
def test_instance_norm(self):
|
||||
def func(x: torch.Tensor):
|
||||
return torch.nn.functional.instance_norm(x)
|
||||
|
||||
onnx_program = torch.onnx.dynamo_export(func, torch.randn(1, 1, 2, 2))
|
||||
# If decomposition is skipped, the model will contain an InstanceNormalization op
|
||||
# instead of BatchNormalization op w/ training=True.
|
||||
assert_op_in_onnx_model(onnx_program.model_proto, "InstanceNormalization")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
common_utils.run_tests()
|
||||
|
@ -18,6 +18,7 @@ import contextlib
|
||||
from typing import Callable, Sequence, Type
|
||||
|
||||
from onnxscript.function_libs.torch_lib.ops import ( # type: ignore[import-not-found]
|
||||
core as torchlib_core,
|
||||
nn as torchlib_nn,
|
||||
)
|
||||
|
||||
@ -119,8 +120,61 @@ class UpsampleBilinear2DDecompSkip(DecompSkip):
|
||||
)
|
||||
|
||||
|
||||
class InstanceNormDecompSkip(DecompSkip):
|
||||
op_callable = torch.instance_norm # type: ignore[attr-defined]
|
||||
onnxscript_function = torchlib_core.aten_instance_norm # type: ignore[attr-defined]
|
||||
new_op_name = "instance_norm"
|
||||
new_op_schema = (
|
||||
"(Tensor input, Tensor? weight, Tensor? bias, "
|
||||
"Tensor? running_mean, Tensor? running_var, "
|
||||
"bool use_input_stats, float momentum, float eps, "
|
||||
"bool cudnn_enabled) -> Tensor"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def register(cls, export_options: torch.onnx.ExportOptions):
|
||||
if not hasattr(torch.ops, _NEW_OP_NAMESPACE) or not hasattr(
|
||||
torch.ops.onnx_export, cls.new_op_name
|
||||
):
|
||||
cls.register_custom_op()
|
||||
|
||||
torch.instance_norm = torch.ops.onnx_export.instance_norm # type: ignore[attr-defined]
|
||||
if export_options.onnx_registry is None:
|
||||
export_options.onnx_registry = torch.onnx.OnnxRegistry()
|
||||
registry = export_options.onnx_registry
|
||||
registry.register_op(
|
||||
function=cls.onnxscript_function,
|
||||
namespace=_NEW_OP_NAMESPACE,
|
||||
op_name=cls.new_op_name,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def unregister(cls):
|
||||
torch.instance_norm = cls.op_callable # type: ignore[attr-defined]
|
||||
|
||||
@classmethod
|
||||
def abstract(
|
||||
cls,
|
||||
input,
|
||||
weight,
|
||||
bias,
|
||||
running_mean,
|
||||
running_var,
|
||||
use_input_stats: bool,
|
||||
momentum: float,
|
||||
eps: float,
|
||||
cudnn_enabled: bool,
|
||||
):
|
||||
return torch.empty(
|
||||
input.size(),
|
||||
dtype=input.dtype,
|
||||
device=input.device,
|
||||
)
|
||||
|
||||
|
||||
_DEFAULT_SKIP_LIST = [
|
||||
UpsampleBilinear2DDecompSkip,
|
||||
InstanceNormDecompSkip,
|
||||
]
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user