[ONNX] Remove type promotion rule for pow (#139527)

ONNX supports different input types in Pow, so type promotion is not needed.

The resulting graph is the following:

```py
ONNXProgram(
    model=
        <
            ir_version=9,
            opset_imports={'': 18, 'pkg.onnxscript.torch_lib.common': 1},
            producer_name='pytorch',
            producer_version='2.6.0a0+git59a1af5',
            domain=None,
            model_version=None,
        >
        graph(
            name=main_graph,
            inputs=(
                %"x"<FLOAT16,[3]>
            ),
            outputs=(
                %"pow_1"<FLOAT16,[3]>
            ),
        ) {
            0 |  # node_Constant_0
                 %"val_0"<?,?> ⬅️ ::Constant() {value=Tensor<FLOAT,[]>(array(2., dtype=float32), name=None)}
            1 |  # node_Pow_1
                 %"pow_1"<FLOAT16,[3]> ⬅️ ::Pow(%"x", %"val_0")
            return %"pow_1"<FLOAT16,[3]>
        }
...
    ,
    exported_program=
        ExportedProgram:
            class GraphModule(torch.nn.Module):
                def forward(self, x: "f16[3]"):
                     # File: /workspace/pytorch/test/onnx/exporter/test_small_models_e2e.py:53 in forward, code: return x**2.0
                    pow_1: "f16[3]" = torch.ops.aten.pow.Tensor_Scalar(x, 2.0);  x = None
                    return (pow_1,)

        Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='pow_1'), target=None)])
        Range constraints: {}

)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139527
Approved by: https://github.com/titaiwangms
This commit is contained in:
Justin Chu
2024-11-02 02:19:47 +00:00
committed by PyTorch MergeBot
parent 7e65060410
commit 387b120549
3 changed files with 20 additions and 7 deletions

View File

@@ -45,7 +45,18 @@ class DynamoExporterTest(common_utils.TestCase):
)
onnx_program = torch.onnx.export(MulModule(), (x,), dynamo=True)
onnx_testing.assert_onnx_program(onnx_program, atol=1e-3, rtol=1)
onnx_testing.assert_onnx_program(onnx_program)
def test_pow_does_not_trigger_type_promotion(self):
    """Exporting ``x**2.0`` on a float16 input must not insert a Cast node.

    ONNX's Pow accepts inputs of differing dtypes, so the exporter no longer
    applies the aten pow type-promotion rule; the exponent constant stays
    float32 while the base remains float16.
    """

    class Model(torch.nn.Module):
        def forward(self, x):
            return x**2.0

    x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16)
    onnx_program = torch.onnx.export(Model(), (x,), dynamo=True)
    # Numeric check: the exported model must agree with eager PyTorch.
    onnx_testing.assert_onnx_program(onnx_program)
    # Structural check: no Cast op means no type promotion was triggered.
    self.assertNotIn("Cast", [node.op_type for node in onnx_program.model.graph])
if __name__ == "__main__":

View File

@@ -4,10 +4,18 @@ from torch.onnx._internal.fx.passes import type_promotion
from torch.testing._internal import common_utils
# Ops excluded from the generated rule set: ONNX does not require type
# promotion for these (Pow accepts mismatched input dtypes natively).
IGNORED_OPS = {"pow", "pow_"}
class TestGeneratedTypePromotionRuleSet(common_utils.TestCase):
def test_generated_rule_set_is_up_to_date(self):
generated_set = type_promotion._GENERATED_ATEN_TYPE_PROMOTION_RULE_SET
latest_set = type_promotion.ElementwiseTypePromotionRuleSetGenerator.generate_from_torch_refs()
latest_set = {rule for rule in latest_set if rule.op_name not in IGNORED_OPS}
# Please update the list in torch/onnx/_internal/fx/passes/type_promotion.py following the instruction
# if this test fails

View File

@@ -885,12 +885,6 @@ _GENERATED_ATEN_TYPE_PROMOTION_RULE_SET = {
[],
ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
),
ElementwiseTypePromotionRule(
"aten", "pow", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG
),
ElementwiseTypePromotionRule(
"aten", "pow_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG
),
ElementwiseTypePromotionRule(
"aten", "prelu", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
),