Fixing ONNX export of logical ops to have correct output datatype (#15185)

Summary:
Currently the PyTorch ONNX exporter exports the logical ops (`lt`, `gt`, `le`, `ge`, `eq`) such that the output type of the corresponding ONNX ops is `tensor(uint8)`. But the ONNX spec allows only `tensor(bool)`, which is why models that contain these ops fail to load properly.

This issue is captured in https://github.com/pytorch/pytorch/issues/11339. Part of this issue, relating to the allowed input types, has been fixed in ONNX spec by houseroad. This PR fixes the other part pertaining to output type.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15185

Differential Revision: D13494873

Pulled By: houseroad

fbshipit-source-id: 069d2f956a5ae9bf0ac2540a32594a31b01adef8
This commit is contained in:
Spandan Tiwari
2018-12-20 12:24:42 -08:00
committed by Facebook Github Bot
parent cb0b096f2b
commit f0f9277c3c
10 changed files with 112 additions and 31 deletions

View File

@ -6,13 +6,14 @@ ModelProto {
GraphProto {
name: "torch-jit-export"
inputs: [{name: "0", type:Tensor dims: 3 4}]
outputs: [{name: "4", type:Tensor dims: 0}]
outputs: [{name: "5", type:Tensor dims: 0}]
initializers: []
nodes: [
Node {type: "Constant", inputs: [], outputs: [1], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "Less", inputs: [0,1], outputs: [2], attributes: []},
Node {type: "Cast", inputs: [2], outputs: [3], attributes: [{ name: 'to', type: int, value: 2}]},
Node {type: "ATen", inputs: [0,3], outputs: [4], attributes: [{ name: 'operator', type: string, value: 'index'}]}
Node {type: "Cast", inputs: [3], outputs: [4], attributes: [{ name: 'to', type: int, value: 2}]},
Node {type: "ATen", inputs: [0,4], outputs: [5], attributes: [{ name: 'operator', type: string, value: 'index'}]}
]
}
opset_import: [OperatorSetIdProto { domain: }],

View File

@ -6,20 +6,21 @@ ModelProto {
GraphProto {
name: "torch-jit-export"
inputs: [{name: "x.1", type:Tensor dims: 1 2 3}]
outputs: [{name: "4", type:Tensor dims: 1 2 3}]
outputs: [{name: "5", type:Tensor dims: 1 2 3}]
initializers: []
nodes: [
Node {type: "ReduceSum", inputs: [x.1], outputs: [1], attributes: [{ name: 'keepdims', type: int, value: 0}]},
Node {type: "Constant", inputs: [], outputs: [2], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "Greater", inputs: [1,2], outputs: [3], attributes: []},
Node {type: "If", inputs: [3], outputs: [4], attributes: [{ name: 'then_branch', type: graph, value:
Node {type: "Cast", inputs: [3], outputs: [4], attributes: [{ name: 'to', type: int, value: 2}]},
Node {type: "If", inputs: [4], outputs: [5], attributes: [{ name: 'then_branch', type: graph, value:
GraphProto {
name: "torch-jit-export1"
inputs: []
outputs: [{name: "5", type:Tensor dims: }]
outputs: [{name: "6", type:Tensor dims: }]
initializers: []
nodes: [
Node {type: "Neg", inputs: [x.1], outputs: [5], attributes: []}
Node {type: "Neg", inputs: [x.1], outputs: [6], attributes: []}
]
}

View File

@ -6,28 +6,29 @@ ModelProto {
GraphProto {
name: "torch-jit-export"
inputs: [{name: "x.1", type:Tensor dims: 1 10}]
outputs: [{name: "8", type:Tensor dims: 10 1}]
outputs: [{name: "9", type:Tensor dims: 10 1}]
initializers: []
nodes: [
Node {type: "Add", inputs: [x.1,x.1], outputs: [1], attributes: []},
Node {type: "ReduceSum", inputs: [1], outputs: [2], attributes: [{ name: 'keepdims', type: int, value: 0}]},
Node {type: "Constant", inputs: [], outputs: [3], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "Greater", inputs: [2,3], outputs: [4], attributes: []},
Node {type: "Transpose", inputs: [1], outputs: [5], attributes: [{ name: 'perm', type: ints, values: [1 0]}]},
Node {type: "Cast", inputs: [4], outputs: [5], attributes: [{ name: 'to', type: int, value: 2}]},
Node {type: "Transpose", inputs: [1], outputs: [6], attributes: [{ name: 'perm', type: ints, values: [1 0]}]},
Node {type: "Transpose", inputs: [1], outputs: [7], attributes: [{ name: 'perm', type: ints, values: [1 0]}]},
Node {type: "If", inputs: [4], outputs: [8], attributes: [{ name: 'then_branch', type: graph, value:
Node {type: "Transpose", inputs: [1], outputs: [8], attributes: [{ name: 'perm', type: ints, values: [1 0]}]},
Node {type: "If", inputs: [5], outputs: [9], attributes: [{ name: 'then_branch', type: graph, value:
GraphProto {
name: "torch-jit-export1"
inputs: []
outputs: [{name: "9", type:Tensor dims: }]
outputs: [{name: "10", type:Tensor dims: }]
initializers: []
nodes: [
Node {type: "If", inputs: [4], outputs: [9], attributes: [{ name: 'then_branch', type: graph, value:
Node {type: "If", inputs: [5], outputs: [10], attributes: [{ name: 'then_branch', type: graph, value:
GraphProto {
name: "torch-jit-export2"
inputs: []
outputs: [{name: "5", type:Tensor dims: }]
outputs: [{name: "6", type:Tensor dims: }]
initializers: []
nodes: [
@ -38,7 +39,7 @@ ModelProto {
GraphProto {
name: "torch-jit-export3"
inputs: []
outputs: [{name: "6", type:Tensor dims: }]
outputs: [{name: "7", type:Tensor dims: }]
initializers: []
nodes: [
@ -53,7 +54,7 @@ ModelProto {
GraphProto {
name: "torch-jit-export4"
inputs: []
outputs: [{name: "7", type:Tensor dims: }]
outputs: [{name: "8", type:Tensor dims: }]
initializers: []
nodes: [

View File

@ -6,28 +6,29 @@ ModelProto {
GraphProto {
name: "torch-jit-export"
inputs: [{name: "x.1", type:Tensor dims: 1 10},{name: "1", type:Tensor dims: 20 10},{name: "2", type:Tensor dims: 20}]
outputs: [{name: "7", type:Tensor dims: 1 20}]
outputs: [{name: "8", type:Tensor dims: 1 20}]
initializers: [TensorProto shape: [20 10],TensorProto shape: [20]]
nodes: [
Node {type: "Add", inputs: [x.1,x.1], outputs: [3], attributes: []},
Node {type: "ReduceSum", inputs: [3], outputs: [4], attributes: [{ name: 'keepdims', type: int, value: 0}]},
Node {type: "Constant", inputs: [], outputs: [5], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "Greater", inputs: [4,5], outputs: [6], attributes: []},
Node {type: "If", inputs: [6], outputs: [7], attributes: [{ name: 'then_branch', type: graph, value:
Node {type: "Cast", inputs: [6], outputs: [7], attributes: [{ name: 'to', type: int, value: 2}]},
Node {type: "If", inputs: [7], outputs: [8], attributes: [{ name: 'then_branch', type: graph, value:
GraphProto {
name: "torch-jit-export1"
inputs: []
outputs: [{name: "8", type:Tensor dims: 1 20}]
outputs: [{name: "9", type:Tensor dims: 1 20}]
initializers: []
nodes: [
Node {type: "If", inputs: [6], outputs: [8], attributes: [{ name: 'then_branch', type: graph, value:
Node {type: "If", inputs: [7], outputs: [9], attributes: [{ name: 'then_branch', type: graph, value:
GraphProto {
name: "torch-jit-export2"
inputs: []
outputs: [{name: "9", type:Tensor dims: 1 20}]
outputs: [{name: "10", type:Tensor dims: 1 20}]
initializers: []
nodes: [
Node {type: "Gemm", inputs: [3,1,2], outputs: [9], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
Node {type: "Gemm", inputs: [3,1,2], outputs: [10], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
]
}
@ -35,10 +36,10 @@ ModelProto {
GraphProto {
name: "torch-jit-export3"
inputs: []
outputs: [{name: "10", type:Tensor dims: 1 20}]
outputs: [{name: "11", type:Tensor dims: 1 20}]
initializers: []
nodes: [
Node {type: "Gemm", inputs: [3,1,2], outputs: [10], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
Node {type: "Gemm", inputs: [3,1,2], outputs: [11], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
]
}
@ -50,10 +51,10 @@ ModelProto {
GraphProto {
name: "torch-jit-export4"
inputs: []
outputs: [{name: "11", type:Tensor dims: 1 20}]
outputs: [{name: "12", type:Tensor dims: 1 20}]
initializers: []
nodes: [
Node {type: "Gemm", inputs: [3,1,2], outputs: [11], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
Node {type: "Gemm", inputs: [3,1,2], outputs: [12], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
]
}

View File

@ -8,6 +8,16 @@ graph {
output: "2"
op_type: "Equal"
}
node {
input: "2"
output: "3"
op_type: "Cast"
attribute {
name: "to"
i: 2
type: INT
}
}
name: "torch-jit-export"
input {
name: "0"
@ -48,7 +58,7 @@ graph {
}
}
output {
name: "2"
name: "3"
type {
tensor_type {
elem_type: 2

View File

@ -13,6 +13,16 @@ graph {
output: "3"
op_type: "Not"
}
node {
input: "3"
output: "4"
op_type: "Cast"
attribute {
name: "to"
i: 2
type: INT
}
}
name: "torch-jit-export"
input {
name: "0"
@ -47,7 +57,7 @@ graph {
}
}
output {
name: "3"
name: "4"
type {
tensor_type {
elem_type: 2

View File

@ -8,6 +8,16 @@ graph {
output: "2"
op_type: "Greater"
}
node {
input: "2"
output: "3"
op_type: "Cast"
attribute {
name: "to"
i: 2
type: INT
}
}
name: "torch-jit-export"
input {
name: "0"
@ -48,7 +58,7 @@ graph {
}
}
output {
name: "2"
name: "3"
type {
tensor_type {
elem_type: 2

View File

@ -13,6 +13,16 @@ graph {
output: "3"
op_type: "Not"
}
node {
input: "3"
output: "4"
op_type: "Cast"
attribute {
name: "to"
i: 2
type: INT
}
}
name: "torch-jit-export"
input {
name: "0"
@ -47,7 +57,7 @@ graph {
}
}
output {
name: "3"
name: "4"
type {
tensor_type {
elem_type: 2

View File

@ -8,6 +8,16 @@ graph {
output: "2"
op_type: "Less"
}
node {
input: "2"
output: "3"
op_type: "Cast"
attribute {
name: "to"
i: 2
type: INT
}
}
name: "torch-jit-export"
input {
name: "0"
@ -48,7 +58,7 @@ graph {
}
}
output {
name: "2"
name: "3"
type {
tensor_type {
elem_type: 2

View File

@ -664,24 +664,50 @@ def upsample_bilinear2d(g, input, output_size, align_corners):
mode_s="linear")
def wrap_logical_op_with_cast_to_uint8(func):
    """Decorator for logical-op symbolics: append a Cast node to uint8 (Byte).

    ONNX comparison ops produce tensor(bool) outputs, while PyTorch's logical
    ops return uint8 tensors, so the exported graph casts the result back.
    """
    def wrap_with_cast(g, input, other):
        bool_result = func(g, input, other)
        return g.op("Cast", bool_result, to_i=cast_pytorch_to_onnx['Byte'])
    return wrap_with_cast
def wrap_logical_op_with_negation(func):
    """Decorator for logical-op symbolics: emit an ONNX ``Not`` over the result.

    Lets ``ge``/``le`` be expressed as the negation of ``lt``/``gt``.
    """
    def wrap_with_not(g, input, other):
        comparison = func(g, input, other)
        return g.op("Not", comparison)
    return wrap_with_not
@wrap_logical_op_with_cast_to_uint8
def gt(g, input, other):
    """Export ``aten::gt``: ONNX ``Greater`` whose output the wrapper casts to uint8."""
    return gt_impl(g, input, other)
def gt_impl(g, input, other):
    """Emit a raw ONNX ``Greater`` node; a scalar ``other`` is promoted to input's type."""
    rhs = _if_scalar_type_as(g, _maybe_get_scalar(other), input)
    return g.op("Greater", input, rhs)
@wrap_logical_op_with_cast_to_uint8
def lt(g, input, other):
    """Export ``aten::lt``: ONNX ``Less`` whose output the wrapper casts to uint8."""
    return lt_impl(g, input, other)
def lt_impl(g, input, other):
    """Emit a raw ONNX ``Less`` node; a scalar ``other`` is promoted to input's type."""
    rhs = _if_scalar_type_as(g, _maybe_get_scalar(other), input)
    return g.op("Less", input, rhs)
@wrap_logical_op_with_cast_to_uint8
@wrap_logical_op_with_negation
def ge(g, input, other):
    """Export ``aten::ge`` as ``Cast(Not(Less(...)))``.

    Calls ``lt_impl`` (the raw ``Less`` emitter) rather than the wrapped ``lt``
    so that the ``Not`` and ``Cast`` wrappers are applied exactly once.
    """
    # Bug fix: the flattened diff left the stale pre-change line
    # ``return g.op("Not", lt(...))`` in place before the new return,
    # producing two consecutive return statements (the second unreachable)
    # and double-applying the Not/Cast wrappers. Keep only the new body.
    other = _maybe_get_scalar(other)
    return lt_impl(g, input, _if_scalar_type_as(g, other, input))
@wrap_logical_op_with_cast_to_uint8
@wrap_logical_op_with_negation
def le(g, input, other):
    """Export ``aten::le`` as ``Cast(Not(Greater(...)))``.

    Calls ``gt_impl`` (the raw ``Greater`` emitter) rather than the wrapped
    ``gt`` so that the ``Not`` and ``Cast`` wrappers are applied exactly once.
    """
    # Bug fix: the flattened diff left the stale pre-change line
    # ``return g.op("Not", gt(...))`` in place before the new return,
    # producing two consecutive return statements (the second unreachable)
    # and double-applying the Not/Cast wrappers. Keep only the new body.
    other = _maybe_get_scalar(other)
    return gt_impl(g, input, _if_scalar_type_as(g, other, input))
def where(g, condition, self, other):
@ -915,6 +941,7 @@ def min(g, self, dim_or_y=None, keepdim=None):
outputs=2)
@wrap_logical_op_with_cast_to_uint8
def eq(g, self, other):
    """Export ``aten::eq``: ONNX ``Equal`` whose output the wrapper casts to uint8."""
    return g.op("Equal", self, other)