mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fixing ONNX export of logical ops to have correct output datatype (#15185)
Summary: Currently, the PyTorch ONNX exporter exports the logical ops (`lt`, `gt`, `le`, `ge`, `eq`) with the output of the corresponding ONNX ops typed as `tensor(uint8)`. However, the ONNX spec allows only `tensor(bool)` for these outputs, which is why models that contain these ops fail to load properly. This issue is captured in https://github.com/pytorch/pytorch/issues/11339. The part of this issue relating to the allowed input types has already been fixed in the ONNX spec by houseroad. This PR fixes the other part, which pertains to the output type. Pull Request resolved: https://github.com/pytorch/pytorch/pull/15185 Differential Revision: D13494873 Pulled By: houseroad fbshipit-source-id: 069d2f956a5ae9bf0ac2540a32594a31b01adef8
This commit is contained in:
committed by
Facebook Github Bot
parent
cb0b096f2b
commit
f0f9277c3c
@ -6,13 +6,14 @@ ModelProto {
|
||||
GraphProto {
|
||||
name: "torch-jit-export"
|
||||
inputs: [{name: "0", type:Tensor dims: 3 4}]
|
||||
outputs: [{name: "4", type:Tensor dims: 0}]
|
||||
outputs: [{name: "5", type:Tensor dims: 0}]
|
||||
initializers: []
|
||||
nodes: [
|
||||
Node {type: "Constant", inputs: [], outputs: [1], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
|
||||
Node {type: "Less", inputs: [0,1], outputs: [2], attributes: []},
|
||||
Node {type: "Cast", inputs: [2], outputs: [3], attributes: [{ name: 'to', type: int, value: 2}]},
|
||||
Node {type: "ATen", inputs: [0,3], outputs: [4], attributes: [{ name: 'operator', type: string, value: 'index'}]}
|
||||
Node {type: "Cast", inputs: [3], outputs: [4], attributes: [{ name: 'to', type: int, value: 2}]},
|
||||
Node {type: "ATen", inputs: [0,4], outputs: [5], attributes: [{ name: 'operator', type: string, value: 'index'}]}
|
||||
]
|
||||
}
|
||||
opset_import: [OperatorSetIdProto { domain: }],
|
||||
|
@ -6,20 +6,21 @@ ModelProto {
|
||||
GraphProto {
|
||||
name: "torch-jit-export"
|
||||
inputs: [{name: "x.1", type:Tensor dims: 1 2 3}]
|
||||
outputs: [{name: "4", type:Tensor dims: 1 2 3}]
|
||||
outputs: [{name: "5", type:Tensor dims: 1 2 3}]
|
||||
initializers: []
|
||||
nodes: [
|
||||
Node {type: "ReduceSum", inputs: [x.1], outputs: [1], attributes: [{ name: 'keepdims', type: int, value: 0}]},
|
||||
Node {type: "Constant", inputs: [], outputs: [2], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
|
||||
Node {type: "Greater", inputs: [1,2], outputs: [3], attributes: []},
|
||||
Node {type: "If", inputs: [3], outputs: [4], attributes: [{ name: 'then_branch', type: graph, value:
|
||||
Node {type: "Cast", inputs: [3], outputs: [4], attributes: [{ name: 'to', type: int, value: 2}]},
|
||||
Node {type: "If", inputs: [4], outputs: [5], attributes: [{ name: 'then_branch', type: graph, value:
|
||||
GraphProto {
|
||||
name: "torch-jit-export1"
|
||||
inputs: []
|
||||
outputs: [{name: "5", type:Tensor dims: }]
|
||||
outputs: [{name: "6", type:Tensor dims: }]
|
||||
initializers: []
|
||||
nodes: [
|
||||
Node {type: "Neg", inputs: [x.1], outputs: [5], attributes: []}
|
||||
Node {type: "Neg", inputs: [x.1], outputs: [6], attributes: []}
|
||||
]
|
||||
}
|
||||
|
||||
|
@ -6,28 +6,29 @@ ModelProto {
|
||||
GraphProto {
|
||||
name: "torch-jit-export"
|
||||
inputs: [{name: "x.1", type:Tensor dims: 1 10}]
|
||||
outputs: [{name: "8", type:Tensor dims: 10 1}]
|
||||
outputs: [{name: "9", type:Tensor dims: 10 1}]
|
||||
initializers: []
|
||||
nodes: [
|
||||
Node {type: "Add", inputs: [x.1,x.1], outputs: [1], attributes: []},
|
||||
Node {type: "ReduceSum", inputs: [1], outputs: [2], attributes: [{ name: 'keepdims', type: int, value: 0}]},
|
||||
Node {type: "Constant", inputs: [], outputs: [3], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
|
||||
Node {type: "Greater", inputs: [2,3], outputs: [4], attributes: []},
|
||||
Node {type: "Transpose", inputs: [1], outputs: [5], attributes: [{ name: 'perm', type: ints, values: [1 0]}]},
|
||||
Node {type: "Cast", inputs: [4], outputs: [5], attributes: [{ name: 'to', type: int, value: 2}]},
|
||||
Node {type: "Transpose", inputs: [1], outputs: [6], attributes: [{ name: 'perm', type: ints, values: [1 0]}]},
|
||||
Node {type: "Transpose", inputs: [1], outputs: [7], attributes: [{ name: 'perm', type: ints, values: [1 0]}]},
|
||||
Node {type: "If", inputs: [4], outputs: [8], attributes: [{ name: 'then_branch', type: graph, value:
|
||||
Node {type: "Transpose", inputs: [1], outputs: [8], attributes: [{ name: 'perm', type: ints, values: [1 0]}]},
|
||||
Node {type: "If", inputs: [5], outputs: [9], attributes: [{ name: 'then_branch', type: graph, value:
|
||||
GraphProto {
|
||||
name: "torch-jit-export1"
|
||||
inputs: []
|
||||
outputs: [{name: "9", type:Tensor dims: }]
|
||||
outputs: [{name: "10", type:Tensor dims: }]
|
||||
initializers: []
|
||||
nodes: [
|
||||
Node {type: "If", inputs: [4], outputs: [9], attributes: [{ name: 'then_branch', type: graph, value:
|
||||
Node {type: "If", inputs: [5], outputs: [10], attributes: [{ name: 'then_branch', type: graph, value:
|
||||
GraphProto {
|
||||
name: "torch-jit-export2"
|
||||
inputs: []
|
||||
outputs: [{name: "5", type:Tensor dims: }]
|
||||
outputs: [{name: "6", type:Tensor dims: }]
|
||||
initializers: []
|
||||
nodes: [
|
||||
|
||||
@ -38,7 +39,7 @@ ModelProto {
|
||||
GraphProto {
|
||||
name: "torch-jit-export3"
|
||||
inputs: []
|
||||
outputs: [{name: "6", type:Tensor dims: }]
|
||||
outputs: [{name: "7", type:Tensor dims: }]
|
||||
initializers: []
|
||||
nodes: [
|
||||
|
||||
@ -53,7 +54,7 @@ ModelProto {
|
||||
GraphProto {
|
||||
name: "torch-jit-export4"
|
||||
inputs: []
|
||||
outputs: [{name: "7", type:Tensor dims: }]
|
||||
outputs: [{name: "8", type:Tensor dims: }]
|
||||
initializers: []
|
||||
nodes: [
|
||||
|
||||
|
@ -6,28 +6,29 @@ ModelProto {
|
||||
GraphProto {
|
||||
name: "torch-jit-export"
|
||||
inputs: [{name: "x.1", type:Tensor dims: 1 10},{name: "1", type:Tensor dims: 20 10},{name: "2", type:Tensor dims: 20}]
|
||||
outputs: [{name: "7", type:Tensor dims: 1 20}]
|
||||
outputs: [{name: "8", type:Tensor dims: 1 20}]
|
||||
initializers: [TensorProto shape: [20 10],TensorProto shape: [20]]
|
||||
nodes: [
|
||||
Node {type: "Add", inputs: [x.1,x.1], outputs: [3], attributes: []},
|
||||
Node {type: "ReduceSum", inputs: [3], outputs: [4], attributes: [{ name: 'keepdims', type: int, value: 0}]},
|
||||
Node {type: "Constant", inputs: [], outputs: [5], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
|
||||
Node {type: "Greater", inputs: [4,5], outputs: [6], attributes: []},
|
||||
Node {type: "If", inputs: [6], outputs: [7], attributes: [{ name: 'then_branch', type: graph, value:
|
||||
Node {type: "Cast", inputs: [6], outputs: [7], attributes: [{ name: 'to', type: int, value: 2}]},
|
||||
Node {type: "If", inputs: [7], outputs: [8], attributes: [{ name: 'then_branch', type: graph, value:
|
||||
GraphProto {
|
||||
name: "torch-jit-export1"
|
||||
inputs: []
|
||||
outputs: [{name: "8", type:Tensor dims: 1 20}]
|
||||
outputs: [{name: "9", type:Tensor dims: 1 20}]
|
||||
initializers: []
|
||||
nodes: [
|
||||
Node {type: "If", inputs: [6], outputs: [8], attributes: [{ name: 'then_branch', type: graph, value:
|
||||
Node {type: "If", inputs: [7], outputs: [9], attributes: [{ name: 'then_branch', type: graph, value:
|
||||
GraphProto {
|
||||
name: "torch-jit-export2"
|
||||
inputs: []
|
||||
outputs: [{name: "9", type:Tensor dims: 1 20}]
|
||||
outputs: [{name: "10", type:Tensor dims: 1 20}]
|
||||
initializers: []
|
||||
nodes: [
|
||||
Node {type: "Gemm", inputs: [3,1,2], outputs: [9], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
|
||||
Node {type: "Gemm", inputs: [3,1,2], outputs: [10], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
|
||||
]
|
||||
}
|
||||
|
||||
@ -35,10 +36,10 @@ ModelProto {
|
||||
GraphProto {
|
||||
name: "torch-jit-export3"
|
||||
inputs: []
|
||||
outputs: [{name: "10", type:Tensor dims: 1 20}]
|
||||
outputs: [{name: "11", type:Tensor dims: 1 20}]
|
||||
initializers: []
|
||||
nodes: [
|
||||
Node {type: "Gemm", inputs: [3,1,2], outputs: [10], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
|
||||
Node {type: "Gemm", inputs: [3,1,2], outputs: [11], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
|
||||
]
|
||||
}
|
||||
|
||||
@ -50,10 +51,10 @@ ModelProto {
|
||||
GraphProto {
|
||||
name: "torch-jit-export4"
|
||||
inputs: []
|
||||
outputs: [{name: "11", type:Tensor dims: 1 20}]
|
||||
outputs: [{name: "12", type:Tensor dims: 1 20}]
|
||||
initializers: []
|
||||
nodes: [
|
||||
Node {type: "Gemm", inputs: [3,1,2], outputs: [11], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
|
||||
Node {type: "Gemm", inputs: [3,1,2], outputs: [12], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
|
||||
]
|
||||
}
|
||||
|
||||
|
@ -8,6 +8,16 @@ graph {
|
||||
output: "2"
|
||||
op_type: "Equal"
|
||||
}
|
||||
node {
|
||||
input: "2"
|
||||
output: "3"
|
||||
op_type: "Cast"
|
||||
attribute {
|
||||
name: "to"
|
||||
i: 2
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
name: "torch-jit-export"
|
||||
input {
|
||||
name: "0"
|
||||
@ -48,7 +58,7 @@ graph {
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "2"
|
||||
name: "3"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 2
|
||||
|
@ -13,6 +13,16 @@ graph {
|
||||
output: "3"
|
||||
op_type: "Not"
|
||||
}
|
||||
node {
|
||||
input: "3"
|
||||
output: "4"
|
||||
op_type: "Cast"
|
||||
attribute {
|
||||
name: "to"
|
||||
i: 2
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
name: "torch-jit-export"
|
||||
input {
|
||||
name: "0"
|
||||
@ -47,7 +57,7 @@ graph {
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "3"
|
||||
name: "4"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 2
|
||||
|
@ -8,6 +8,16 @@ graph {
|
||||
output: "2"
|
||||
op_type: "Greater"
|
||||
}
|
||||
node {
|
||||
input: "2"
|
||||
output: "3"
|
||||
op_type: "Cast"
|
||||
attribute {
|
||||
name: "to"
|
||||
i: 2
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
name: "torch-jit-export"
|
||||
input {
|
||||
name: "0"
|
||||
@ -48,7 +58,7 @@ graph {
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "2"
|
||||
name: "3"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 2
|
||||
|
@ -13,6 +13,16 @@ graph {
|
||||
output: "3"
|
||||
op_type: "Not"
|
||||
}
|
||||
node {
|
||||
input: "3"
|
||||
output: "4"
|
||||
op_type: "Cast"
|
||||
attribute {
|
||||
name: "to"
|
||||
i: 2
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
name: "torch-jit-export"
|
||||
input {
|
||||
name: "0"
|
||||
@ -47,7 +57,7 @@ graph {
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "3"
|
||||
name: "4"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 2
|
||||
|
@ -8,6 +8,16 @@ graph {
|
||||
output: "2"
|
||||
op_type: "Less"
|
||||
}
|
||||
node {
|
||||
input: "2"
|
||||
output: "3"
|
||||
op_type: "Cast"
|
||||
attribute {
|
||||
name: "to"
|
||||
i: 2
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
name: "torch-jit-export"
|
||||
input {
|
||||
name: "0"
|
||||
@ -48,7 +58,7 @@ graph {
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "2"
|
||||
name: "3"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 2
|
||||
|
@ -664,24 +664,50 @@ def upsample_bilinear2d(g, input, output_size, align_corners):
|
||||
mode_s="linear")
|
||||
|
||||
|
||||
def wrap_logical_op_with_cast_to_uint8(func):
    """Decorate a logical-op symbolic so that its result is cast afterwards.

    The wrapped symbolic ``func(g, input, other)`` is called unchanged and
    its result is fed into an ONNX ``Cast`` node whose ``to`` attribute is
    ``cast_pytorch_to_onnx['Byte']`` (uint8, going by this decorator's name).

    Args:
        func: symbolic function called as ``func(g, input, other)``.

    Returns:
        A function with the same ``(g, input, other)`` signature that returns
        the ``Cast`` node created through ``g.op``.
    """
    # Local import keeps this block self-contained (no module-level change).
    import functools

    # functools.wraps keeps the decorated symbolic's __name__/__doc__ rather
    # than exposing every logical op as "wrap_with_cast".
    @functools.wraps(func)
    def wrap_with_cast(g, input, other):
        logical_result = func(g, input, other)
        return g.op("Cast", logical_result, to_i=cast_pytorch_to_onnx['Byte'])

    return wrap_with_cast
|
||||
|
||||
|
||||
def wrap_logical_op_with_negation(func):
    """Decorate a logical-op symbolic so that its result is negated.

    The wrapped symbolic ``func(g, input, other)`` is called unchanged and
    its result is fed into an ONNX ``Not`` node.

    Args:
        func: symbolic function called as ``func(g, input, other)``.

    Returns:
        A function with the same ``(g, input, other)`` signature that returns
        the ``Not`` node created through ``g.op``.
    """
    # Local import keeps this block self-contained (no module-level change).
    import functools

    # functools.wraps keeps the decorated symbolic's __name__/__doc__ rather
    # than exposing every negated op as "wrap_with_not".
    @functools.wraps(func)
    def wrap_with_not(g, input, other):
        logical_result = func(g, input, other)
        return g.op("Not", logical_result)

    return wrap_with_not
|
||||
|
||||
|
||||
@wrap_logical_op_with_cast_to_uint8
def gt(g, input, other):
    """Symbolic for ``gt``: delegates to ``gt_impl``.

    The decorator wraps the node returned by ``gt_impl`` in a ``Cast``.
    """
    greater_node = gt_impl(g, input, other)
    return greater_node
|
||||
|
||||
|
||||
def gt_impl(g, input, other):
    """Emit an undecorated ONNX ``Greater`` node for ``input > other``.

    ``other`` is first passed through ``_maybe_get_scalar`` and then through
    ``_if_scalar_type_as`` together with ``input`` (both helpers are defined
    elsewhere in this file; presumably they align a scalar ``other`` with the
    type of ``input`` -- confirm against their definitions).
    """
    maybe_scalar = _maybe_get_scalar(other)
    rhs = _if_scalar_type_as(g, maybe_scalar, input)
    return g.op("Greater", input, rhs)
|
||||
|
||||
|
||||
@wrap_logical_op_with_cast_to_uint8
def lt(g, input, other):
    """Symbolic for ``lt``: delegates to ``lt_impl``.

    The decorator wraps the node returned by ``lt_impl`` in a ``Cast``.
    """
    less_node = lt_impl(g, input, other)
    return less_node
|
||||
|
||||
|
||||
def lt_impl(g, input, other):
    """Emit an undecorated ONNX ``Less`` node for ``input < other``.

    ``other`` is first passed through ``_maybe_get_scalar`` and then through
    ``_if_scalar_type_as`` together with ``input`` (both helpers are defined
    elsewhere in this file; presumably they align a scalar ``other`` with the
    type of ``input`` -- confirm against their definitions).
    """
    maybe_scalar = _maybe_get_scalar(other)
    rhs = _if_scalar_type_as(g, maybe_scalar, input)
    return g.op("Less", input, rhs)
|
||||
|
||||
|
||||
@wrap_logical_op_with_cast_to_uint8
@wrap_logical_op_with_negation
def ge(g, input, other):
    """Symbolic for ``ge``, expressed as the negation of ``Less``.

    Decorators apply bottom-up: the ``Less`` node returned here is wrapped in
    ``Not`` by ``wrap_logical_op_with_negation`` and that result is wrapped in
    ``Cast`` by ``wrap_logical_op_with_cast_to_uint8``.
    """
    other = _maybe_get_scalar(other)
    # Use the undecorated lt_impl, not lt: lt already carries the Cast
    # decorator, and the Not is supplied by this function's own decorator, so
    # going through lt (or adding an explicit "Not" here) would apply them twice.
    return lt_impl(g, input, _if_scalar_type_as(g, other, input))
|
||||
|
||||
|
||||
@wrap_logical_op_with_cast_to_uint8
@wrap_logical_op_with_negation
def le(g, input, other):
    """Symbolic for ``le``, expressed as the negation of ``Greater``.

    Decorators apply bottom-up: the ``Greater`` node returned here is wrapped
    in ``Not`` by ``wrap_logical_op_with_negation`` and that result is wrapped
    in ``Cast`` by ``wrap_logical_op_with_cast_to_uint8``.
    """
    other = _maybe_get_scalar(other)
    # Use the undecorated gt_impl, not gt: gt already carries the Cast
    # decorator, and the Not is supplied by this function's own decorator, so
    # going through gt (or adding an explicit "Not" here) would apply them twice.
    return gt_impl(g, input, _if_scalar_type_as(g, other, input))
|
||||
|
||||
|
||||
def where(g, condition, self, other):
|
||||
@ -915,6 +941,7 @@ def min(g, self, dim_or_y=None, keepdim=None):
|
||||
outputs=2)
|
||||
|
||||
|
||||
@wrap_logical_op_with_cast_to_uint8
def eq(g, self, other):
    """Symbolic for ``eq``: emits an ONNX ``Equal`` node on ``self`` and ``other``.

    The decorator wraps the returned ``Equal`` node in a ``Cast``.
    """
    equal_node = g.op("Equal", self, other)
    return equal_node
|
||||
|
||||
|
Reference in New Issue
Block a user