mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[ONNX] Remove type promotion rule for pow (#139527)
ONNX supports different input types in Pow, so type promotion is not needed. The resulting graph is the following: ```py ONNXProgram( model= < ir_version=9, opset_imports={'': 18, 'pkg.onnxscript.torch_lib.common': 1}, producer_name='pytorch', producer_version='2.6.0a0+git59a1af5', domain=None, model_version=None, > graph( name=main_graph, inputs=( %"x"<FLOAT16,[3]> ), outputs=( %"pow_1"<FLOAT16,[3]> ), ) { 0 | # node_Constant_0 %"val_0"<?,?> ⬅️ ::Constant() {value=Tensor<FLOAT,[]>(array(2., dtype=float32), name=None)} 1 | # node_Pow_1 %"pow_1"<FLOAT16,[3]> ⬅️ ::Pow(%"x", %"val_0") return %"pow_1"<FLOAT16,[3]> } ... , exported_program= ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "f16[3]"): # File: /workspace/pytorch/test/onnx/exporter/test_small_models_e2e.py:53 in forward, code: return x**2.0 pow_1: "f16[3]" = torch.ops.aten.pow.Tensor_Scalar(x, 2.0); x = None return (pow_1,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='pow_1'), target=None)]) Range constraints: {} ) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/139527 Approved by: https://github.com/titaiwangms
This commit is contained in:
committed by
PyTorch MergeBot
parent
7e65060410
commit
387b120549
@ -45,7 +45,18 @@ class DynamoExporterTest(common_utils.TestCase):
|
||||
)
|
||||
|
||||
onnx_program = torch.onnx.export(MulModule(), (x,), dynamo=True)
|
||||
onnx_testing.assert_onnx_program(onnx_program, atol=1e-3, rtol=1)
|
||||
onnx_testing.assert_onnx_program(onnx_program)
|
||||
|
||||
def test_pow_does_not_trigger_type_promotion(self):
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x**2.0
|
||||
|
||||
x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16)
|
||||
|
||||
onnx_program = torch.onnx.export(Model(), (x,), dynamo=True)
|
||||
onnx_testing.assert_onnx_program(onnx_program)
|
||||
self.assertNotIn("Cast", [node.op_type for node in onnx_program.model.graph])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -4,10 +4,18 @@ from torch.onnx._internal.fx.passes import type_promotion
|
||||
from torch.testing._internal import common_utils
|
||||
|
||||
|
||||
# The following ops are ignored because we do not need these rules enabled for ONNX
|
||||
IGNORED_OPS = {
|
||||
"pow",
|
||||
"pow_",
|
||||
}
|
||||
|
||||
|
||||
class TestGeneratedTypePromotionRuleSet(common_utils.TestCase):
|
||||
def test_generated_rule_set_is_up_to_date(self):
|
||||
generated_set = type_promotion._GENERATED_ATEN_TYPE_PROMOTION_RULE_SET
|
||||
latest_set = type_promotion.ElementwiseTypePromotionRuleSetGenerator.generate_from_torch_refs()
|
||||
latest_set = {rule for rule in latest_set if rule.op_name not in IGNORED_OPS}
|
||||
|
||||
# Please update the list in torch/onnx/_internal/fx/passes/type_promotion.py following the instruction
|
||||
# if this test fails
|
||||
|
@ -885,12 +885,6 @@ _GENERATED_ATEN_TYPE_PROMOTION_RULE_SET = {
|
||||
[],
|
||||
ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
|
||||
),
|
||||
ElementwiseTypePromotionRule(
|
||||
"aten", "pow", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG
|
||||
),
|
||||
ElementwiseTypePromotionRule(
|
||||
"aten", "pow_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG
|
||||
),
|
||||
ElementwiseTypePromotionRule(
|
||||
"aten", "prelu", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
|
||||
),
|
||||
|
Reference in New Issue
Block a user