mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix ONNX ATen fallback for non-caffe2 engines
This PR introduces 3 BC changes: First, this PR propagates the `BUILD_CAFFE2` flag to `libtorch` and `libtorch_python`, which is necessary for non-Caffe2 ONNX runtimes when using the `ONNX_ATEN_FALLBACK` operator export type. Second, as a complement to https://github.com/pytorch/pytorch/pull/68490, this PR refactors Caffe2's ATen ops symbolics to consider not only the `operator_export_type` (aka `ONNX_ATEN_FALLBACK`) to emit Caffe2 ATen ops, but also whether `BUILD_CAFFE2` (which is called `torch.onnx._CAFFE2_ATEN_FALLBACK` in the Python binding) is set. Lastly, it renames `onnx::ATen` to `aten::ATen` for ONNX spec consistency in a BC fashion. ONNX doesn't have an `ATen` op in its spec, but the PyTorch ONNX converter emits it. Non-Caffe2 backend engines would be misled by such an operator's name/domain. A non-ideal workaround would be to have ATen ops handled based on their name and ignore the (non-compliant) domain. Moreover, users could incorrectly file bugs to either ONNX or ONNX Runtime when they inspect the model and notice the presence of an unspecified ONNX operator. Pull Request resolved: https://github.com/pytorch/pytorch/pull/73954 Approved by: https://github.com/BowenBao, https://github.com/malfet, https://github.com/garymm, https://github.com/jiafatom
This commit is contained in:
committed by
PyTorch MergeBot
parent
b142a224c6
commit
9bbe1d632e
@ -247,7 +247,7 @@ namespace c10 {
|
|||||||
_(onnx, Less) \
|
_(onnx, Less) \
|
||||||
_(onnx, LessOrEqual) \
|
_(onnx, LessOrEqual) \
|
||||||
_(onnx, Not) \
|
_(onnx, Not) \
|
||||||
_(onnx, ATen) \
|
_(aten, ATen) \
|
||||||
_(onnx, Split) \
|
_(onnx, Split) \
|
||||||
_(onnx, ConstantOfShape) \
|
_(onnx, ConstantOfShape) \
|
||||||
_(onnx, Cast) \
|
_(onnx, Cast) \
|
||||||
@ -316,7 +316,8 @@ namespace c10 {
|
|||||||
_(attr, new_axis) \
|
_(attr, new_axis) \
|
||||||
_(attr, warn_id) \
|
_(attr, warn_id) \
|
||||||
_(attr, allowzero) \
|
_(attr, allowzero) \
|
||||||
_(attr, seen_none)
|
_(attr, seen_none) \
|
||||||
|
_(attr, overload_name)
|
||||||
|
|
||||||
enum class _keys : unique_t {
|
enum class _keys : unique_t {
|
||||||
#define DEFINE_KEY(ns, s) ns##_##s,
|
#define DEFINE_KEY(ns, s) ns##_##s,
|
||||||
|
|||||||
@ -59,7 +59,7 @@ def main(args):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Utilitity to generate Caffe2 benchmark models.")
|
description="Utility to generate Caffe2 benchmark models.")
|
||||||
parser.add_argument("operator", help="Caffe2 operator to benchmark.")
|
parser.add_argument("operator", help="Caffe2 operator to benchmark.")
|
||||||
parser.add_argument("-b", "--blob",
|
parser.add_argument("-b", "--blob",
|
||||||
help="Instantiate a blob --blob name=dim1,dim2,dim3",
|
help="Instantiate a blob --blob name=dim1,dim2,dim3",
|
||||||
|
|||||||
@ -1943,6 +1943,8 @@ if(BUILD_PYTHON)
|
|||||||
# ---[ Python.
|
# ---[ Python.
|
||||||
if(BUILD_CAFFE2)
|
if(BUILD_CAFFE2)
|
||||||
add_library(caffe2_pybind11_state MODULE ${Caffe2_CPU_PYTHON_SRCS})
|
add_library(caffe2_pybind11_state MODULE ${Caffe2_CPU_PYTHON_SRCS})
|
||||||
|
target_compile_definitions(torch PRIVATE BUILD_CAFFE2)
|
||||||
|
target_compile_definitions(torch_python PRIVATE BUILD_CAFFE2)
|
||||||
if(USE_NUMPY)
|
if(USE_NUMPY)
|
||||||
target_compile_options(caffe2_pybind11_state PRIVATE "-DUSE_NUMPY")
|
target_compile_options(caffe2_pybind11_state PRIVATE "-DUSE_NUMPY")
|
||||||
target_link_libraries(caffe2_pybind11_state PRIVATE numpy::numpy)
|
target_link_libraries(caffe2_pybind11_state PRIVATE numpy::numpy)
|
||||||
|
|||||||
@ -72,7 +72,7 @@ class Add(torch.autograd.Function):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def symbolic(g, a, b):
|
def symbolic(g, a, b):
|
||||||
return g.op("ATen", a, b, operator_s = "add")
|
return g.at("add", a, b)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, a, b):
|
def forward(ctx, a, b):
|
||||||
|
|||||||
@ -179,7 +179,7 @@ private:
|
|||||||
std::vector<std::string> attrs;
|
std::vector<std::string> attrs;
|
||||||
for (const auto i : c10::irange(operator_def.arg_size())) {
|
for (const auto i : c10::irange(operator_def.arg_size())) {
|
||||||
auto & attr = operator_def.arg(i);
|
auto & attr = operator_def.arg(i);
|
||||||
if(attr.name() == "operator" || attr.name() == "type" || attr.name() == "overload_name" ) {
|
if (attr.name() == "operator" || attr.name() == "type" || attr.name() == "overload_name") {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
attrs.push_back(attr.name());
|
attrs.push_back(attr.name());
|
||||||
|
|||||||
@ -1,9 +1,4 @@
|
|||||||
|
from caffe2.python import core
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
from caffe2.python import core, dyndep
|
|
||||||
from hypothesis import given
|
from hypothesis import given
|
||||||
|
|
||||||
import caffe2.python.hypothesis_test_util as hu
|
import caffe2.python.hypothesis_test_util as hu
|
||||||
|
|||||||
@ -38,8 +38,8 @@ torch.onnx.export(MyModule(),
|
|||||||
# graph(%input : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
|
# graph(%input : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
|
||||||
# %y : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
|
# %y : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
|
||||||
# %2 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = onnx::Relu(%input)
|
# %2 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = onnx::Relu(%input)
|
||||||
# %3 : Tensor = onnx::ATen[operator="mul"](%2, %2)
|
# %3 : Tensor = aten::ATen[operator="mul"](%2, %2)
|
||||||
# %4 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = onnx::ATen[operator="add"](%3, %y)
|
# %4 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::ATen[operator="add"](%3, %y)
|
||||||
# return (%4)
|
# return (%4)
|
||||||
|
|
||||||
graph = onnx.load(f.name)
|
graph = onnx.load(f.name)
|
||||||
|
|||||||
@ -106,7 +106,7 @@ def main(args):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Utilitity to generate Caffe2 benchmark models.")
|
description="Utility to generate Caffe2 benchmark models.")
|
||||||
parser.add_argument("operator", help="Caffe2 operator to benchmark.")
|
parser.add_argument("operator", help="Caffe2 operator to benchmark.")
|
||||||
parser.add_argument("-b", "--blob",
|
parser.add_argument("-b", "--blob",
|
||||||
help="Instantiate a blob --blob name=dim1,dim2,dim3",
|
help="Instantiate a blob --blob name=dim1,dim2,dim3",
|
||||||
|
|||||||
@ -11,7 +11,7 @@ ModelProto {
|
|||||||
nodes: [
|
nodes: [
|
||||||
Node {type: "Add", inputs: [0,1], outputs: [2], attributes: []},
|
Node {type: "Add", inputs: [0,1], outputs: [2], attributes: []},
|
||||||
Node {type: "Constant", inputs: [], outputs: [3], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
|
Node {type: "Constant", inputs: [], outputs: [3], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
|
||||||
Node {type: "ATen", inputs: [2,3], outputs: [4,5], attributes: [{ name: 'operator', type: string, value: 'qr'}, { name: 'overload_name', type: string, value: ''}]}
|
Node {type: "ATen", domain: "org.pytorch.aten", inputs: [2,3], outputs: [4,5], attributes: [{ name: 'operator', type: string, value: 'qr'}, { name: 'overload_name', type: string, value: ''}]}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
opset_import: [OperatorSetIdProto { domain: }OperatorSetIdProto { domain: org.pytorch.aten}],
|
opset_import: [OperatorSetIdProto { domain: }OperatorSetIdProto { domain: org.pytorch.aten}],
|
||||||
|
|||||||
@ -9,7 +9,7 @@ ModelProto {
|
|||||||
outputs: [{name: "2", type:Tensor dims: 3 4}]
|
outputs: [{name: "2", type:Tensor dims: 3 4}]
|
||||||
initializers: []
|
initializers: []
|
||||||
nodes: [
|
nodes: [
|
||||||
Node {type: "ATen", inputs: [0,1], outputs: [2], attributes: [{ name: 'operator', type: string, value: 'fmod'}, { name: 'overload_name', type: string, value: ''}]}
|
Node {type: "ATen", domain: "org.pytorch.aten", inputs: [0,1], outputs: [2], attributes: [{ name: 'operator', type: string, value: 'fmod'}, { name: 'overload_name', type: string, value: ''}]}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
opset_import: [OperatorSetIdProto { domain: }OperatorSetIdProto { domain: org.pytorch.aten}],
|
opset_import: [OperatorSetIdProto { domain: }OperatorSetIdProto { domain: org.pytorch.aten}],
|
||||||
|
|||||||
@ -13,7 +13,7 @@ ModelProto {
|
|||||||
Node {type: "Less", inputs: [0,1], outputs: [2], attributes: []},
|
Node {type: "Less", inputs: [0,1], outputs: [2], attributes: []},
|
||||||
Node {type: "Cast", inputs: [2], outputs: [3], attributes: [{ name: 'to', type: int, value: 2}]},
|
Node {type: "Cast", inputs: [2], outputs: [3], attributes: [{ name: 'to', type: int, value: 2}]},
|
||||||
Node {type: "Cast", inputs: [3], outputs: [4], attributes: [{ name: 'to', type: int, value: 9}]},
|
Node {type: "Cast", inputs: [3], outputs: [4], attributes: [{ name: 'to', type: int, value: 9}]},
|
||||||
Node {type: "ATen", inputs: [0,4], outputs: [5], attributes: [{ name: 'operator', type: string, value: 'index'}, { name: 'overload_name', type: string, value: ''}]}
|
Node {type: "ATen", domain: "org.pytorch.aten", inputs: [0,4], outputs: [5], attributes: [{ name: 'operator', type: string, value: 'index'}, { name: 'overload_name', type: string, value: ''}]}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
opset_import: [OperatorSetIdProto { domain: }OperatorSetIdProto { domain: org.pytorch.aten}],
|
opset_import: [OperatorSetIdProto { domain: }OperatorSetIdProto { domain: org.pytorch.aten}],
|
||||||
|
|||||||
@ -18,6 +18,7 @@ graph {
|
|||||||
s: ""
|
s: ""
|
||||||
type: STRING
|
type: STRING
|
||||||
}
|
}
|
||||||
|
domain: "org.pytorch.aten"
|
||||||
}
|
}
|
||||||
name: "torch_jit"
|
name: "torch_jit"
|
||||||
input {
|
input {
|
||||||
@ -56,3 +57,7 @@ graph {
|
|||||||
opset_import {
|
opset_import {
|
||||||
version: 13
|
version: 13
|
||||||
}
|
}
|
||||||
|
opset_import {
|
||||||
|
domain: "org.pytorch.aten"
|
||||||
|
version: 1
|
||||||
|
}
|
||||||
|
|||||||
@ -6,19 +6,24 @@ graph {
|
|||||||
input: "emb.weight"
|
input: "emb.weight"
|
||||||
input: "input_1"
|
input: "input_1"
|
||||||
output: "onnx::Add_3"
|
output: "onnx::Add_3"
|
||||||
name: "ATenOp_0"
|
name: "ATen_0"
|
||||||
op_type: "ATenOp"
|
op_type: "ATen"
|
||||||
attribute {
|
attribute {
|
||||||
name: "custom_attributes_json"
|
name: "custom_attributes_json"
|
||||||
s: "{\"padding_idx\":-1,\"scale_grad_by_freq\":false,\"sparse\":false}"
|
s: "{\"padding_idx\":-1,\"scale_grad_by_freq\":false,\"sparse\":false}"
|
||||||
type: STRING
|
type: STRING
|
||||||
}
|
}
|
||||||
attribute {
|
attribute {
|
||||||
name: "name"
|
name: "operator"
|
||||||
s: "aten::embedding"
|
s: "embedding"
|
||||||
type: STRING
|
type: STRING
|
||||||
}
|
}
|
||||||
domain: "com.microsoft"
|
attribute {
|
||||||
|
name: "overload_name"
|
||||||
|
s: ""
|
||||||
|
type: STRING
|
||||||
|
}
|
||||||
|
domain: "org.pytorch.aten"
|
||||||
}
|
}
|
||||||
node {
|
node {
|
||||||
input: "onnx::Add_3"
|
input: "onnx::Add_3"
|
||||||
@ -145,27 +150,11 @@ graph {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
value_info {
|
|
||||||
name: "onnx::Add_3"
|
|
||||||
type {
|
|
||||||
tensor_type {
|
|
||||||
elem_type: 1
|
|
||||||
shape {
|
|
||||||
dim {
|
|
||||||
dim_param: "ATenOponnx::Add_3_dim_0"
|
|
||||||
}
|
|
||||||
dim {
|
|
||||||
dim_param: "ATenOponnx::Add_3_dim_1"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
opset_import {
|
opset_import {
|
||||||
version: 12
|
version: 12
|
||||||
}
|
}
|
||||||
opset_import {
|
opset_import {
|
||||||
domain: "com.microsoft"
|
domain: "org.pytorch.aten"
|
||||||
version: 1
|
version: 1
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2,30 +2,9 @@ ir_version: 7
|
|||||||
producer_name: "pytorch"
|
producer_name: "pytorch"
|
||||||
producer_version: "CURRENT_VERSION"
|
producer_version: "CURRENT_VERSION"
|
||||||
graph {
|
graph {
|
||||||
node {
|
|
||||||
output: "onnx::Cast_3"
|
|
||||||
op_type: "Constant"
|
|
||||||
attribute {
|
|
||||||
name: "value"
|
|
||||||
t {
|
|
||||||
data_type: 7
|
|
||||||
raw_data: "\001\000\000\000\000\000\000\000"
|
|
||||||
}
|
|
||||||
type: TENSOR
|
|
||||||
}
|
|
||||||
}
|
|
||||||
node {
|
|
||||||
input: "onnx::Cast_3"
|
|
||||||
output: "onnx::Loop_4"
|
|
||||||
op_type: "Cast"
|
|
||||||
attribute {
|
|
||||||
name: "to"
|
|
||||||
i: 9
|
|
||||||
type: INT
|
|
||||||
}
|
|
||||||
}
|
|
||||||
node {
|
node {
|
||||||
output: "5"
|
output: "5"
|
||||||
|
name: "Constant_0"
|
||||||
op_type: "Constant"
|
op_type: "Constant"
|
||||||
attribute {
|
attribute {
|
||||||
name: "value"
|
name: "value"
|
||||||
@ -40,10 +19,12 @@ graph {
|
|||||||
node {
|
node {
|
||||||
input: "input"
|
input: "input"
|
||||||
output: "onnx::Gather_6"
|
output: "onnx::Gather_6"
|
||||||
|
name: "Shape_1"
|
||||||
op_type: "Shape"
|
op_type: "Shape"
|
||||||
}
|
}
|
||||||
node {
|
node {
|
||||||
output: "onnx::Gather_7"
|
output: "onnx::Gather_7"
|
||||||
|
name: "Constant_2"
|
||||||
op_type: "Constant"
|
op_type: "Constant"
|
||||||
attribute {
|
attribute {
|
||||||
name: "value"
|
name: "value"
|
||||||
@ -58,6 +39,7 @@ graph {
|
|||||||
input: "onnx::Gather_6"
|
input: "onnx::Gather_6"
|
||||||
input: "onnx::Gather_7"
|
input: "onnx::Gather_7"
|
||||||
output: "onnx::Unsqueeze_8"
|
output: "onnx::Unsqueeze_8"
|
||||||
|
name: "Gather_3"
|
||||||
op_type: "Gather"
|
op_type: "Gather"
|
||||||
attribute {
|
attribute {
|
||||||
name: "axis"
|
name: "axis"
|
||||||
@ -67,6 +49,7 @@ graph {
|
|||||||
}
|
}
|
||||||
node {
|
node {
|
||||||
output: "onnx::Unsqueeze_9"
|
output: "onnx::Unsqueeze_9"
|
||||||
|
name: "Constant_4"
|
||||||
op_type: "Constant"
|
op_type: "Constant"
|
||||||
attribute {
|
attribute {
|
||||||
name: "value"
|
name: "value"
|
||||||
@ -82,12 +65,14 @@ graph {
|
|||||||
input: "onnx::Unsqueeze_8"
|
input: "onnx::Unsqueeze_8"
|
||||||
input: "onnx::Unsqueeze_9"
|
input: "onnx::Unsqueeze_9"
|
||||||
output: "onnx::Concat_10"
|
output: "onnx::Concat_10"
|
||||||
|
name: "Unsqueeze_5"
|
||||||
op_type: "Unsqueeze"
|
op_type: "Unsqueeze"
|
||||||
}
|
}
|
||||||
node {
|
node {
|
||||||
input: "offsets"
|
input: "offsets"
|
||||||
input: "onnx::Concat_10"
|
input: "onnx::Concat_10"
|
||||||
output: "onnx::Slice_11"
|
output: "onnx::Slice_11"
|
||||||
|
name: "Concat_6"
|
||||||
op_type: "Concat"
|
op_type: "Concat"
|
||||||
attribute {
|
attribute {
|
||||||
name: "axis"
|
name: "axis"
|
||||||
@ -97,6 +82,7 @@ graph {
|
|||||||
}
|
}
|
||||||
node {
|
node {
|
||||||
output: "onnx::Slice_12"
|
output: "onnx::Slice_12"
|
||||||
|
name: "Constant_7"
|
||||||
op_type: "Constant"
|
op_type: "Constant"
|
||||||
attribute {
|
attribute {
|
||||||
name: "value"
|
name: "value"
|
||||||
@ -110,6 +96,7 @@ graph {
|
|||||||
}
|
}
|
||||||
node {
|
node {
|
||||||
output: "onnx::Slice_13"
|
output: "onnx::Slice_13"
|
||||||
|
name: "Constant_8"
|
||||||
op_type: "Constant"
|
op_type: "Constant"
|
||||||
attribute {
|
attribute {
|
||||||
name: "value"
|
name: "value"
|
||||||
@ -123,6 +110,7 @@ graph {
|
|||||||
}
|
}
|
||||||
node {
|
node {
|
||||||
output: "onnx::Slice_14"
|
output: "onnx::Slice_14"
|
||||||
|
name: "Constant_9"
|
||||||
op_type: "Constant"
|
op_type: "Constant"
|
||||||
attribute {
|
attribute {
|
||||||
name: "value"
|
name: "value"
|
||||||
@ -136,6 +124,7 @@ graph {
|
|||||||
}
|
}
|
||||||
node {
|
node {
|
||||||
output: "onnx::Slice_15"
|
output: "onnx::Slice_15"
|
||||||
|
name: "Constant_10"
|
||||||
op_type: "Constant"
|
op_type: "Constant"
|
||||||
attribute {
|
attribute {
|
||||||
name: "value"
|
name: "value"
|
||||||
@ -154,15 +143,18 @@ graph {
|
|||||||
input: "onnx::Slice_12"
|
input: "onnx::Slice_12"
|
||||||
input: "onnx::Slice_15"
|
input: "onnx::Slice_15"
|
||||||
output: "onnx::Shape_16"
|
output: "onnx::Shape_16"
|
||||||
|
name: "Slice_11"
|
||||||
op_type: "Slice"
|
op_type: "Slice"
|
||||||
}
|
}
|
||||||
node {
|
node {
|
||||||
input: "onnx::Shape_16"
|
input: "onnx::Shape_16"
|
||||||
output: "onnx::Gather_17"
|
output: "onnx::Gather_17"
|
||||||
|
name: "Shape_12"
|
||||||
op_type: "Shape"
|
op_type: "Shape"
|
||||||
}
|
}
|
||||||
node {
|
node {
|
||||||
output: "onnx::Gather_18"
|
output: "onnx::Gather_18"
|
||||||
|
name: "Constant_13"
|
||||||
op_type: "Constant"
|
op_type: "Constant"
|
||||||
attribute {
|
attribute {
|
||||||
name: "value"
|
name: "value"
|
||||||
@ -177,6 +169,7 @@ graph {
|
|||||||
input: "onnx::Gather_17"
|
input: "onnx::Gather_17"
|
||||||
input: "onnx::Gather_18"
|
input: "onnx::Gather_18"
|
||||||
output: "onnx::Loop_19"
|
output: "onnx::Loop_19"
|
||||||
|
name: "Gather_14"
|
||||||
op_type: "Gather"
|
op_type: "Gather"
|
||||||
attribute {
|
attribute {
|
||||||
name: "axis"
|
name: "axis"
|
||||||
@ -186,8 +179,9 @@ graph {
|
|||||||
}
|
}
|
||||||
node {
|
node {
|
||||||
input: "onnx::Loop_19"
|
input: "onnx::Loop_19"
|
||||||
input: "onnx::Loop_4"
|
input: "onnx::Loop_33"
|
||||||
output: "20"
|
output: "20"
|
||||||
|
name: "Loop_15"
|
||||||
op_type: "Loop"
|
op_type: "Loop"
|
||||||
attribute {
|
attribute {
|
||||||
name: "body"
|
name: "body"
|
||||||
@ -196,7 +190,7 @@ graph {
|
|||||||
input: "onnx::Slice_11"
|
input: "onnx::Slice_11"
|
||||||
input: "21"
|
input: "21"
|
||||||
output: "23"
|
output: "23"
|
||||||
name: "Gather_0"
|
name: "Gather_16"
|
||||||
op_type: "Gather"
|
op_type: "Gather"
|
||||||
attribute {
|
attribute {
|
||||||
name: "axis"
|
name: "axis"
|
||||||
@ -208,7 +202,7 @@ graph {
|
|||||||
input: "onnx::Shape_16"
|
input: "onnx::Shape_16"
|
||||||
input: "21"
|
input: "21"
|
||||||
output: "24"
|
output: "24"
|
||||||
name: "Gather_1"
|
name: "Gather_17"
|
||||||
op_type: "Gather"
|
op_type: "Gather"
|
||||||
attribute {
|
attribute {
|
||||||
name: "axis"
|
name: "axis"
|
||||||
@ -218,7 +212,7 @@ graph {
|
|||||||
}
|
}
|
||||||
node {
|
node {
|
||||||
output: "25"
|
output: "25"
|
||||||
name: "Constant_2"
|
name: "Constant_18"
|
||||||
op_type: "Constant"
|
op_type: "Constant"
|
||||||
attribute {
|
attribute {
|
||||||
name: "value"
|
name: "value"
|
||||||
@ -234,12 +228,12 @@ graph {
|
|||||||
input: "23"
|
input: "23"
|
||||||
input: "25"
|
input: "25"
|
||||||
output: "26"
|
output: "26"
|
||||||
name: "Unsqueeze_3"
|
name: "Unsqueeze_19"
|
||||||
op_type: "Unsqueeze"
|
op_type: "Unsqueeze"
|
||||||
}
|
}
|
||||||
node {
|
node {
|
||||||
output: "27"
|
output: "27"
|
||||||
name: "Constant_4"
|
name: "Constant_20"
|
||||||
op_type: "Constant"
|
op_type: "Constant"
|
||||||
attribute {
|
attribute {
|
||||||
name: "value"
|
name: "value"
|
||||||
@ -255,7 +249,7 @@ graph {
|
|||||||
input: "24"
|
input: "24"
|
||||||
input: "27"
|
input: "27"
|
||||||
output: "28"
|
output: "28"
|
||||||
name: "Unsqueeze_5"
|
name: "Unsqueeze_21"
|
||||||
op_type: "Unsqueeze"
|
op_type: "Unsqueeze"
|
||||||
}
|
}
|
||||||
node {
|
node {
|
||||||
@ -264,14 +258,14 @@ graph {
|
|||||||
input: "28"
|
input: "28"
|
||||||
input: "5"
|
input: "5"
|
||||||
output: "29"
|
output: "29"
|
||||||
name: "Slice_6"
|
name: "Slice_22"
|
||||||
op_type: "Slice"
|
op_type: "Slice"
|
||||||
}
|
}
|
||||||
node {
|
node {
|
||||||
input: "weight"
|
input: "weight"
|
||||||
input: "29"
|
input: "29"
|
||||||
output: "30"
|
output: "30"
|
||||||
name: "Gather_7"
|
name: "Gather_23"
|
||||||
op_type: "Gather"
|
op_type: "Gather"
|
||||||
attribute {
|
attribute {
|
||||||
name: "axis"
|
name: "axis"
|
||||||
@ -282,7 +276,7 @@ graph {
|
|||||||
node {
|
node {
|
||||||
input: "30"
|
input: "30"
|
||||||
output: "31"
|
output: "31"
|
||||||
name: "ReduceMean_8"
|
name: "ReduceMean_24"
|
||||||
op_type: "ReduceMean"
|
op_type: "ReduceMean"
|
||||||
attribute {
|
attribute {
|
||||||
name: "axes"
|
name: "axes"
|
||||||
@ -296,9 +290,9 @@ graph {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
node {
|
node {
|
||||||
input: "onnx::Loop_4"
|
input: "onnx::Loop_33"
|
||||||
output: "32"
|
output: "32"
|
||||||
name: "Cast_9"
|
name: "Cast_25"
|
||||||
op_type: "Cast"
|
op_type: "Cast"
|
||||||
attribute {
|
attribute {
|
||||||
name: "to"
|
name: "to"
|
||||||
@ -362,6 +356,11 @@ graph {
|
|||||||
name: "weight"
|
name: "weight"
|
||||||
raw_data: "\264\314\344\275\017A\376\276\313\374&>J\266a\277s\306\\=\212\032+?\211[t\275\344[\357\276Dk\\\276OKb?\234\'B\277A\334\274\2767N\257\276\320s\263\277\371+\244>:\314\202\277K\200L??\001\275\275\236u4\2774\032\315\277\214\004\224>Z\320\372>\267B\305\276\346G6\277N\265.\276\343\316\272\277t\364a>\201)|>p\223\251\277Qm2?\346\275)\277\354\235\233?\027X\277\277\253\206a?\354\335\226\277L\032o\277\251J\021\277\311\360\215\276\312\274\013\300\252\320\273>\220\"p?\267\020\000<R\262\240\276\343\016\224\2779\241\353?8;\202\277\023\020\234?E\370#>\222\233\314?\334\360?\275|t\303\277\214\351\000\300\3065\302\2775\206\306>X\251\227\277x\2160?U^\251?d\221\350?\237F.?\rp9?9X\004=/c\324\277SL\360\277\'\274<?t\375l?\342\270l?\240\352:>\332\356\226\275\211\035\241>*\271\204\277>\025W>\036K\035?\036\233\200=\035\313\250\276\017\003\346\277\374p_?\313WD?!\006\351\275\232\\q\277\230\007A?"
|
raw_data: "\264\314\344\275\017A\376\276\313\374&>J\266a\277s\306\\=\212\032+?\211[t\275\344[\357\276Dk\\\276OKb?\234\'B\277A\334\274\2767N\257\276\320s\263\277\371+\244>:\314\202\277K\200L??\001\275\275\236u4\2774\032\315\277\214\004\224>Z\320\372>\267B\305\276\346G6\277N\265.\276\343\316\272\277t\364a>\201)|>p\223\251\277Qm2?\346\275)\277\354\235\233?\027X\277\277\253\206a?\354\335\226\277L\032o\277\251J\021\277\311\360\215\276\312\274\013\300\252\320\273>\220\"p?\267\020\000<R\262\240\276\343\016\224\2779\241\353?8;\202\277\023\020\234?E\370#>\222\233\314?\334\360?\275|t\303\277\214\351\000\300\3065\302\2775\206\306>X\251\227\277x\2160?U^\251?d\221\350?\237F.?\rp9?9X\004=/c\324\277SL\360\277\'\274<?t\375l?\342\270l?\240\352:>\332\356\226\275\211\035\241>*\271\204\277>\025W>\036K\035?\036\233\200=\035\313\250\276\017\003\346\277\374p_?\313WD?!\006\351\275\232\\q\277\230\007A?"
|
||||||
}
|
}
|
||||||
|
initializer {
|
||||||
|
data_type: 9
|
||||||
|
name: "onnx::Loop_33"
|
||||||
|
raw_data: "\001"
|
||||||
|
}
|
||||||
input {
|
input {
|
||||||
name: "input"
|
name: "input"
|
||||||
type {
|
type {
|
||||||
@ -404,6 +403,16 @@ graph {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
input {
|
||||||
|
name: "onnx::Loop_33"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 9
|
||||||
|
shape {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
output {
|
output {
|
||||||
name: "20"
|
name: "20"
|
||||||
type {
|
type {
|
||||||
|
|||||||
@ -7,6 +7,7 @@ graph {
|
|||||||
input: "weight"
|
input: "weight"
|
||||||
input: "bias"
|
input: "bias"
|
||||||
output: "3"
|
output: "3"
|
||||||
|
name: "ATen_0"
|
||||||
op_type: "ATen"
|
op_type: "ATen"
|
||||||
attribute {
|
attribute {
|
||||||
name: "cudnn_enable"
|
name: "cudnn_enable"
|
||||||
@ -34,6 +35,7 @@ graph {
|
|||||||
s: ""
|
s: ""
|
||||||
type: STRING
|
type: STRING
|
||||||
}
|
}
|
||||||
|
domain: "org.pytorch.aten"
|
||||||
}
|
}
|
||||||
name: "torch_jit"
|
name: "torch_jit"
|
||||||
initializer {
|
initializer {
|
||||||
@ -130,3 +132,7 @@ graph {
|
|||||||
opset_import {
|
opset_import {
|
||||||
version: 13
|
version: 13
|
||||||
}
|
}
|
||||||
|
opset_import {
|
||||||
|
domain: "org.pytorch.aten"
|
||||||
|
version: 1
|
||||||
|
}
|
||||||
|
|||||||
@ -20,12 +20,15 @@ import os
|
|||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
import torch.testing._internal.common_utils as common
|
import torch.testing._internal.common_utils as common
|
||||||
|
from torch.testing._internal.common_utils import skipIfCaffe2
|
||||||
|
|
||||||
'''Usage: python test/onnx/test_operators.py [--no-onnx] [--produce-onnx-test-data]
|
'''Usage: python test/onnx/test_operators.py [--no-onnx] [--produce-onnx-test-data]
|
||||||
--no-onnx: no onnx python dependence
|
--no-onnx: no onnx python dependence
|
||||||
--produce-onnx-test-data: generate onnx test data
|
--produce-onnx-test-data: generate onnx test data
|
||||||
--accept: accept onnx updates and overwrite models
|
--accept: accept onnx updates and overwrite models
|
||||||
'''
|
'''
|
||||||
|
import unittest
|
||||||
|
unittest.TestCase.maxDiff = None
|
||||||
|
|
||||||
_onnx_test = False # flag to produce onnx test cases.
|
_onnx_test = False # flag to produce onnx test cases.
|
||||||
_onnx_dep = True # flag to import onnx package.
|
_onnx_dep = True # flag to import onnx package.
|
||||||
@ -322,6 +325,7 @@ class TestOperators(TestCase):
|
|||||||
x = torch.randn(20, 16, 50)
|
x = torch.randn(20, 16, 50)
|
||||||
self.assertONNX(nn.MaxPool1d(3, stride=2, return_indices=True), x)
|
self.assertONNX(nn.MaxPool1d(3, stride=2, return_indices=True), x)
|
||||||
|
|
||||||
|
@skipIfCaffe2
|
||||||
def test_at_op(self):
|
def test_at_op(self):
|
||||||
x = torch.randn(3, 4)
|
x = torch.randn(3, 4)
|
||||||
|
|
||||||
@ -339,7 +343,8 @@ class TestOperators(TestCase):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return MyFun.apply(x)
|
return MyFun.apply(x)
|
||||||
|
|
||||||
self.assertONNX(MyModule(), x)
|
self.assertONNX(MyModule(), x,
|
||||||
|
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
|
||||||
|
|
||||||
def test_clip(self):
|
def test_clip(self):
|
||||||
x = torch.randn(3, 4, requires_grad=True)
|
x = torch.randn(3, 4, requires_grad=True)
|
||||||
@ -588,6 +593,7 @@ class TestOperators(TestCase):
|
|||||||
self.assertONNX(nn.BatchNorm2d(128, affine=False, momentum=0.3), x,
|
self.assertONNX(nn.BatchNorm2d(128, affine=False, momentum=0.3), x,
|
||||||
keep_initializers_as_inputs=True)
|
keep_initializers_as_inputs=True)
|
||||||
|
|
||||||
|
@skipIfCaffe2
|
||||||
def test_embedding_bags(self):
|
def test_embedding_bags(self):
|
||||||
emb_bag = nn.EmbeddingBag(10, 8)
|
emb_bag = nn.EmbeddingBag(10, 8)
|
||||||
input = torch.tensor([1, 2, 3, 4]).long()
|
input = torch.tensor([1, 2, 3, 4]).long()
|
||||||
@ -787,6 +793,7 @@ class TestOperators(TestCase):
|
|||||||
input2 = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2)
|
input2 = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2)
|
||||||
self.assertONNX(BitshiftModel(), (input, input2), opset_version=11)
|
self.assertONNX(BitshiftModel(), (input, input2), opset_version=11)
|
||||||
|
|
||||||
|
@skipIfCaffe2
|
||||||
def test_layer_norm_aten(self):
|
def test_layer_norm_aten(self):
|
||||||
model = torch.nn.LayerNorm([10, 10])
|
model = torch.nn.LayerNorm([10, 10])
|
||||||
x = torch.randn(20, 5, 10, 10)
|
x = torch.randn(20, 5, 10, 10)
|
||||||
@ -954,7 +961,7 @@ class TestOperators(TestCase):
|
|||||||
f'"sparse":{str(sparse).lower()}'
|
f'"sparse":{str(sparse).lower()}'
|
||||||
'}'
|
'}'
|
||||||
)
|
)
|
||||||
output = g.op("com.microsoft::ATenOp", weight, indices, name_s='aten::embedding',
|
output = g.at("embedding", weight, indices,
|
||||||
custom_attributes_json_s=custom_attributes_json)
|
custom_attributes_json_s=custom_attributes_json)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@ -978,6 +985,7 @@ class TestOperators(TestCase):
|
|||||||
unregister_custom_op_symbolic('::embedding', _onnx_opset_version)
|
unregister_custom_op_symbolic('::embedding', _onnx_opset_version)
|
||||||
|
|
||||||
# This is test_aten_embedding_1 with shape inference on custom symbolic aten::embedding.
|
# This is test_aten_embedding_1 with shape inference on custom symbolic aten::embedding.
|
||||||
|
@skipIfCaffe2
|
||||||
def test_aten_embedding_2(self):
|
def test_aten_embedding_2(self):
|
||||||
_onnx_opset_version = 12
|
_onnx_opset_version = 12
|
||||||
|
|
||||||
@ -990,7 +998,7 @@ class TestOperators(TestCase):
|
|||||||
f'"sparse":{str(sparse).lower()}'
|
f'"sparse":{str(sparse).lower()}'
|
||||||
'}'
|
'}'
|
||||||
)
|
)
|
||||||
output = g.op("com.microsoft::ATenOp", weight, indices, name_s='aten::embedding',
|
output = g.at("embedding", weight, indices,
|
||||||
custom_attributes_json_s=custom_attributes_json)
|
custom_attributes_json_s=custom_attributes_json)
|
||||||
|
|
||||||
# do shape inference and set it via setType
|
# do shape inference and set it via setType
|
||||||
@ -1016,7 +1024,9 @@ class TestOperators(TestCase):
|
|||||||
x = torch.ones(32, dtype=torch.long)
|
x = torch.ones(32, dtype=torch.long)
|
||||||
y = torch.randn(1, 8)
|
y = torch.randn(1, 8)
|
||||||
self.assertONNX(model, (x, y), opset_version=_onnx_opset_version, input_names=['input_1', 'input_2'],
|
self.assertONNX(model, (x, y), opset_version=_onnx_opset_version, input_names=['input_1', 'input_2'],
|
||||||
dynamic_axes={"input_1": {0: "dim_0"}, 'input_2': {0: "dim_1", 1: "dim_2"}})
|
dynamic_axes={"input_1": {0: "dim_0"}, 'input_2': {0: "dim_1", 1: "dim_2"}},
|
||||||
|
keep_initializers_as_inputs=False,
|
||||||
|
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
|
||||||
|
|
||||||
unregister_custom_op_symbolic('::embedding', _onnx_opset_version)
|
unregister_custom_op_symbolic('::embedding', _onnx_opset_version)
|
||||||
|
|
||||||
|
|||||||
@ -62,6 +62,8 @@ from torch.testing._internal.common_quantized import (
|
|||||||
override_qengines,
|
override_qengines,
|
||||||
)
|
)
|
||||||
from torch.testing._internal.jit_utils import JitTestCase
|
from torch.testing._internal.jit_utils import JitTestCase
|
||||||
|
from torch.testing._internal.common_utils import skipIfNoCaffe2
|
||||||
|
|
||||||
from hypothesis import given
|
from hypothesis import given
|
||||||
from hypothesis import strategies as st
|
from hypothesis import strategies as st
|
||||||
import torch.testing._internal.hypothesis_utils as hu
|
import torch.testing._internal.hypothesis_utils as hu
|
||||||
@ -1464,6 +1466,7 @@ class TestQuantizeEagerONNXExport(JitTestCase):
|
|||||||
onnx_model = export_to_onnx(model, data, input_names)
|
onnx_model = export_to_onnx(model, data, input_names)
|
||||||
|
|
||||||
@skipIfNoFBGEMM
|
@skipIfNoFBGEMM
|
||||||
|
@skipIfNoCaffe2
|
||||||
def test_lower_graph_linear(self):
|
def test_lower_graph_linear(self):
|
||||||
model = torch.ao.quantization.QuantWrapper(torch.nn.Linear(5, 10, bias=True)).to(dtype=torch.float)
|
model = torch.ao.quantization.QuantWrapper(torch.nn.Linear(5, 10, bias=True)).to(dtype=torch.float)
|
||||||
data_numpy = np.random.rand(1, 2, 5).astype(np.float32)
|
data_numpy = np.random.rand(1, 2, 5).astype(np.float32)
|
||||||
@ -1471,6 +1474,7 @@ class TestQuantizeEagerONNXExport(JitTestCase):
|
|||||||
self._test_lower_graph_impl(model, data)
|
self._test_lower_graph_impl(model, data)
|
||||||
|
|
||||||
@skipIfNoFBGEMM
|
@skipIfNoFBGEMM
|
||||||
|
@skipIfNoCaffe2
|
||||||
def test_lower_graph_conv2d(self):
|
def test_lower_graph_conv2d(self):
|
||||||
model = torch.ao.quantization.QuantWrapper(torch.nn.Conv2d(3, 5, 2, bias=True)).to(dtype=torch.float)
|
model = torch.ao.quantization.QuantWrapper(torch.nn.Conv2d(3, 5, 2, bias=True)).to(dtype=torch.float)
|
||||||
data_numpy = np.random.rand(1, 3, 6, 6).astype(np.float32)
|
data_numpy = np.random.rand(1, 3, 6, 6).astype(np.float32)
|
||||||
|
|||||||
@ -362,7 +362,6 @@ if(USE_NUMPY)
|
|||||||
target_compile_definitions(torch_python PRIVATE USE_NUMPY)
|
target_compile_definitions(torch_python PRIVATE USE_NUMPY)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS BUILD_CAFFE2)
|
|
||||||
if(HAVE_SOVERSION)
|
if(HAVE_SOVERSION)
|
||||||
set_target_properties(torch_python PROPERTIES
|
set_target_properties(torch_python PROPERTIES
|
||||||
VERSION ${TORCH_VERSION} SOVERSION ${TORCH_SOVERSION})
|
VERSION ${TORCH_VERSION} SOVERSION ${TORCH_SOVERSION})
|
||||||
|
|||||||
@ -1905,7 +1905,8 @@ static std::unordered_set<std::string> nodeTypeReliableForTracer = {
|
|||||||
"onnx::Cast",
|
"onnx::Cast",
|
||||||
"onnx::Constant",
|
"onnx::Constant",
|
||||||
"onnx::Relu",
|
"onnx::Relu",
|
||||||
"com.microsoft::Gelu"};
|
"com.microsoft::Gelu",
|
||||||
|
"aten::ATen"};
|
||||||
|
|
||||||
void UpdateReliable(
|
void UpdateReliable(
|
||||||
torch::jit::Value* output,
|
torch::jit::Value* output,
|
||||||
|
|||||||
@ -113,7 +113,7 @@ void validateBlock(
|
|||||||
WithInsertPoint guard(node);
|
WithInsertPoint guard(node);
|
||||||
auto* new_node =
|
auto* new_node =
|
||||||
b->owningGraph()->insertNode(b->owningGraph()->create(
|
b->owningGraph()->insertNode(b->owningGraph()->create(
|
||||||
Symbol(::c10::onnx::ATen),
|
Symbol(::c10::aten::ATen),
|
||||||
node->inputs(),
|
node->inputs(),
|
||||||
node->outputs().size()));
|
node->outputs().size()));
|
||||||
for (size_t i = 0; i < node->outputs().size(); ++i) {
|
for (size_t i = 0; i < node->outputs().size(); ++i) {
|
||||||
@ -1163,8 +1163,8 @@ void GraphEncoder::EncodeIntermediateValueInfo(
|
|||||||
const Value* v) {
|
const Value* v) {
|
||||||
// Motivation is to encode ValueInfo for onnx local function nodes.
|
// Motivation is to encode ValueInfo for onnx local function nodes.
|
||||||
auto n = v->node();
|
auto n = v->node();
|
||||||
if (n->kind().is_onnx()) {
|
if (n->kind().is_onnx() || n->kind().is_aten()) {
|
||||||
// Encode value info only for non-onnx nodes.
|
// Encode value info only for non-onnx or non-ATen nodes.
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (n->owningGraph() != graph_.get()) {
|
if (n->owningGraph() != graph_.get()) {
|
||||||
|
|||||||
@ -329,6 +329,10 @@ def _is_scalar_list(x):
|
|||||||
element_type in scalar_name_to_pytorch.keys() and \
|
element_type in scalar_name_to_pytorch.keys() and \
|
||||||
(scalar_name_to_pytorch[element_type] in cast_pytorch_to_onnx.keys())
|
(scalar_name_to_pytorch[element_type] in cast_pytorch_to_onnx.keys())
|
||||||
|
|
||||||
|
def is_caffe2_aten_fallback():
|
||||||
|
return (_operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK and
|
||||||
|
torch.onnx._CAFFE2_ATEN_FALLBACK)
|
||||||
|
|
||||||
def _get_tensor_rank(x):
|
def _get_tensor_rank(x):
|
||||||
if not _is_tensor(x) or x.type() is None:
|
if not _is_tensor(x) or x.type() is None:
|
||||||
return None
|
return None
|
||||||
|
|||||||
@ -551,10 +551,7 @@ def arange(g, *args):
|
|||||||
def _dim_arange(g, like, dim):
|
def _dim_arange(g, like, dim):
|
||||||
like_shape = g.op("Shape", like)
|
like_shape = g.op("Shape", like)
|
||||||
stop = g.op("Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0)
|
stop = g.op("Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0)
|
||||||
# Caffe2-specific op
|
if sym_help.is_caffe2_aten_fallback():
|
||||||
is_caffe2_aten_fallback = (sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK and
|
|
||||||
torch.onnx._CAFFE2_ATEN_FALLBACK)
|
|
||||||
if is_caffe2_aten_fallback:
|
|
||||||
return g.op("_caffe2::Range", stop)
|
return g.op("_caffe2::Range", stop)
|
||||||
return arange(g, stop, 4, None, None, None)
|
return arange(g, stop, 4, None, None, None)
|
||||||
|
|
||||||
@ -643,7 +640,8 @@ def index(g, self, index):
|
|||||||
def index_fill(g, self, dim, index, value):
|
def index_fill(g, self, dim, index, value):
|
||||||
dim_value = sym_help._parse_arg(dim, "i")
|
dim_value = sym_help._parse_arg(dim, "i")
|
||||||
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
|
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
|
||||||
return g.at("index_fill", self, index, value, dim_i=dim_value, overload_name="int_Scalar")
|
return g.at("index_fill", self, index, value, overload_name="int_Scalar", dim_i=dim_value)
|
||||||
|
|
||||||
expanded_index_shape, expanded_index = sym_help._index_fill_reshape_helper(g, self, dim, index)
|
expanded_index_shape, expanded_index = sym_help._index_fill_reshape_helper(g, self, dim, index)
|
||||||
value = sym_help._maybe_get_scalar(value)
|
value = sym_help._maybe_get_scalar(value)
|
||||||
value = sym_help._if_scalar_type_as(g, value, self)
|
value = sym_help._if_scalar_type_as(g, value, self)
|
||||||
|
|||||||
@ -565,7 +565,7 @@ def transpose(g, self, dim0, dim1):
|
|||||||
# if we don't have dim information we cannot
|
# if we don't have dim information we cannot
|
||||||
# output a permute so use ATen instead
|
# output a permute so use ATen instead
|
||||||
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
|
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
|
||||||
return g.at("transpose", self, dim0_i=dim0, dim1_i=dim1, overload_name="int")
|
return g.at("transpose", self, overload_name="int", dim0_i=dim0, dim1_i=dim1)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Unsupported: ONNX export of transpose for tensor "
|
raise RuntimeError("Unsupported: ONNX export of transpose for tensor "
|
||||||
"of unknown rank.")
|
"of unknown rank.")
|
||||||
@ -1581,7 +1581,8 @@ def index_put(g, self, indices_list_value, values, accumulate):
|
|||||||
def index_fill(g, self, dim, index, value):
|
def index_fill(g, self, dim, index, value):
|
||||||
dim_value = sym_help._parse_arg(dim, "i")
|
dim_value = sym_help._parse_arg(dim, "i")
|
||||||
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
|
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
|
||||||
return g.at("index_fill", self, index, value, dim_i=dim_value, overload_name="int_Scalar")
|
return g.at("index_fill", self, index, value, overload_name="int_Scalar", dim_i=dim_value)
|
||||||
|
|
||||||
expanded_index_shape, expanded_index = sym_help._index_fill_reshape_helper(g, self, dim, index)
|
expanded_index_shape, expanded_index = sym_help._index_fill_reshape_helper(g, self, dim, index)
|
||||||
value = sym_help._maybe_get_scalar(value)
|
value = sym_help._maybe_get_scalar(value)
|
||||||
value = sym_help._if_scalar_type_as(g, value, self)
|
value = sym_help._if_scalar_type_as(g, value, self)
|
||||||
@ -1647,9 +1648,7 @@ def type_as(g, self, other):
|
|||||||
|
|
||||||
@parse_args("v", "v", "i", "f")
|
@parse_args("v", "v", "i", "f")
|
||||||
def cosine_similarity(g, x1, x2, dim, eps):
|
def cosine_similarity(g, x1, x2, dim, eps):
|
||||||
# preserve legacy behavior for Caffe2
|
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
|
||||||
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK and \
|
|
||||||
torch.onnx._CAFFE2_ATEN_FALLBACK:
|
|
||||||
return g.at("cosine_similarity", x1, x2, dim_i=dim, eps_f=eps)
|
return g.at("cosine_similarity", x1, x2, dim_i=dim, eps_f=eps)
|
||||||
cross = sym_help._reducesum_helper(g, mul(g, x1, x2),
|
cross = sym_help._reducesum_helper(g, mul(g, x1, x2),
|
||||||
axes_i=[dim], keepdims_i=0)
|
axes_i=[dim], keepdims_i=0)
|
||||||
@ -2599,10 +2598,7 @@ rnn_relu = _one_hidden_rnn("RNN_RELU")
|
|||||||
def _dim_arange(g, like, dim):
|
def _dim_arange(g, like, dim):
|
||||||
like_shape = g.op("Shape", like)
|
like_shape = g.op("Shape", like)
|
||||||
stop = g.op("Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0)
|
stop = g.op("Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0)
|
||||||
# Caffe2-specific op
|
if sym_help.is_caffe2_aten_fallback():
|
||||||
is_caffe2_aten_fallback = (sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK and
|
|
||||||
torch.onnx._CAFFE2_ATEN_FALLBACK)
|
|
||||||
if is_caffe2_aten_fallback:
|
|
||||||
return g.op("_caffe2::Range", stop)
|
return g.op("_caffe2::Range", stop)
|
||||||
else:
|
else:
|
||||||
# aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
|
# aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
|
||||||
|
|||||||
@ -190,7 +190,7 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa
|
|||||||
torch._C._jit_pass_peephole(graph, True)
|
torch._C._jit_pass_peephole(graph, True)
|
||||||
torch._C._jit_pass_fuse_addmm(graph)
|
torch._C._jit_pass_fuse_addmm(graph)
|
||||||
torch._C._jit_pass_lint(graph)
|
torch._C._jit_pass_lint(graph)
|
||||||
from torch.onnx.symbolic_helper import _onnx_shape_inference, _export_onnx_opset_version
|
from torch.onnx.symbolic_helper import _onnx_shape_inference, _export_onnx_opset_version, is_caffe2_aten_fallback
|
||||||
|
|
||||||
torch._C._jit_pass_peephole(graph, True)
|
torch._C._jit_pass_peephole(graph, True)
|
||||||
torch._C._jit_pass_lower_all_tuples(graph)
|
torch._C._jit_pass_lower_all_tuples(graph)
|
||||||
@ -212,13 +212,10 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa
|
|||||||
torch._C._jit_pass_onnx_remove_print(graph)
|
torch._C._jit_pass_onnx_remove_print(graph)
|
||||||
torch._C._jit_pass_onnx_preprocess_caffe2(graph)
|
torch._C._jit_pass_onnx_preprocess_caffe2(graph)
|
||||||
|
|
||||||
# Caffe2-specific optimization
|
|
||||||
is_caffe2_aten_fallback = (operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK and
|
|
||||||
torch.onnx._CAFFE2_ATEN_FALLBACK)
|
|
||||||
torch.onnx.symbolic_helper._quantized_ops.clear()
|
torch.onnx.symbolic_helper._quantized_ops.clear()
|
||||||
# Unpack quantized weights for conv and linear ops and insert into graph.
|
# Unpack quantized weights for conv and linear ops and insert into graph.
|
||||||
torch._C._jit_pass_onnx_unpack_quantized_weights(graph, params_dict, is_caffe2_aten_fallback)
|
torch._C._jit_pass_onnx_unpack_quantized_weights(graph, params_dict, is_caffe2_aten_fallback())
|
||||||
if is_caffe2_aten_fallback:
|
if is_caffe2_aten_fallback():
|
||||||
# Insert permutes before and after each conv op to ensure correct order.
|
# Insert permutes before and after each conv op to ensure correct order.
|
||||||
torch._C._jit_pass_onnx_quantization_insert_permutes(graph, params_dict)
|
torch._C._jit_pass_onnx_quantization_insert_permutes(graph, params_dict)
|
||||||
|
|
||||||
@ -289,7 +286,7 @@ def warn_on_static_input_change(input_states):
|
|||||||
|
|
||||||
def _resolve_args_by_export_type(arg_name, arg_value, operator_export_type):
|
def _resolve_args_by_export_type(arg_name, arg_value, operator_export_type):
|
||||||
# This helper method resolves the arguments that are ignored when export_type != operator_export_type.ONNX
|
# This helper method resolves the arguments that are ignored when export_type != operator_export_type.ONNX
|
||||||
if operator_export_type is not operator_export_type.ONNX:
|
if operator_export_type is not operator_export_type.ONNX and torch.onnx._CAFFE2_ATEN_FALLBACK:
|
||||||
if arg_value is True:
|
if arg_value is True:
|
||||||
warnings.warn("`{}' can be set to True only when 'operator_export_type' is "
|
warnings.warn("`{}' can be set to True only when 'operator_export_type' is "
|
||||||
"`ONNX`. Since 'operator_export_type' is not set to 'ONNX', "
|
"`ONNX`. Since 'operator_export_type' is not set to 'ONNX', "
|
||||||
@ -959,7 +956,8 @@ def _add_attribute(node, key, value, aten):
|
|||||||
name, kind = m.group(1), m.group(2)
|
name, kind = m.group(1), m.group(2)
|
||||||
if _is_onnx_list(value):
|
if _is_onnx_list(value):
|
||||||
kind += "s"
|
kind += "s"
|
||||||
if aten:
|
from torch.onnx.symbolic_helper import is_caffe2_aten_fallback
|
||||||
|
if aten and is_caffe2_aten_fallback():
|
||||||
if isinstance(value, torch.Tensor):
|
if isinstance(value, torch.Tensor):
|
||||||
# Caffe2 proto does not support tensor attribute.
|
# Caffe2 proto does not support tensor attribute.
|
||||||
if value.numel() > 1:
|
if value.numel() > 1:
|
||||||
@ -1119,14 +1117,13 @@ def _run_symbolic_function(g, block, n, inputs, env, operator_export_type=Operat
|
|||||||
try:
|
try:
|
||||||
import torch
|
import torch
|
||||||
from torch.onnx.symbolic_helper import _export_onnx_opset_version as opset_version
|
from torch.onnx.symbolic_helper import _export_onnx_opset_version as opset_version
|
||||||
|
from torch.onnx.symbolic_helper import is_caffe2_aten_fallback
|
||||||
import torch.onnx.symbolic_registry as sym_registry
|
import torch.onnx.symbolic_registry as sym_registry
|
||||||
|
|
||||||
sym_registry.register_version("", opset_version)
|
sym_registry.register_version("", opset_version)
|
||||||
|
|
||||||
# Caffe2-specific: Quantized op symbolics are registered for opset 9 only.
|
# Caffe2-specific: Quantized op symbolics are registered for opset 9 only.
|
||||||
is_caffe2_aten_fallback = (operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK and
|
if is_caffe2_aten_fallback() and opset_version == 9:
|
||||||
torch.onnx._CAFFE2_ATEN_FALLBACK)
|
|
||||||
if is_caffe2_aten_fallback and opset_version == 9:
|
|
||||||
import torch.onnx.symbolic_caffe2
|
import torch.onnx.symbolic_caffe2
|
||||||
torch.onnx.symbolic_caffe2.register_quantized_ops("caffe2", opset_version)
|
torch.onnx.symbolic_caffe2.register_quantized_ops("caffe2", opset_version)
|
||||||
|
|
||||||
@ -1137,11 +1134,10 @@ def _run_symbolic_function(g, block, n, inputs, env, operator_export_type=Operat
|
|||||||
else:
|
else:
|
||||||
ns_op_name = n.kind()
|
ns_op_name = n.kind()
|
||||||
ns, op_name = ns_op_name.split("::")
|
ns, op_name = ns_op_name.split("::")
|
||||||
|
|
||||||
domain = ns
|
domain = ns
|
||||||
if ns == "aten":
|
if ns == "aten":
|
||||||
domain = ""
|
domain = ""
|
||||||
elif ns == "quantized" and is_caffe2_aten_fallback:
|
elif ns == "quantized" and is_caffe2_aten_fallback():
|
||||||
domain = "caffe2"
|
domain = "caffe2"
|
||||||
|
|
||||||
if sym_registry.is_registered_op(op_name, domain, opset_version):
|
if sym_registry.is_registered_op(op_name, domain, opset_version):
|
||||||
@ -1165,7 +1161,7 @@ def _run_symbolic_function(g, block, n, inputs, env, operator_export_type=Operat
|
|||||||
attrs = {k + "_" + n.kindOf(k)[0]: n[k] for k in n.attributeNames()}
|
attrs = {k + "_" + n.kindOf(k)[0]: n[k] for k in n.attributeNames()}
|
||||||
outputs = n.outputsSize()
|
outputs = n.outputsSize()
|
||||||
attrs["outputs"] = outputs
|
attrs["outputs"] = outputs
|
||||||
return g.at(op_name, *inputs, aten=True, **attrs)
|
return g.at(op_name, *inputs, **attrs)
|
||||||
else:
|
else:
|
||||||
raise sym_registry.UnsupportedOperatorError(domain, op_name, opset_version)
|
raise sym_registry.UnsupportedOperatorError(domain, op_name, opset_version)
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
@ -1181,6 +1177,7 @@ def _run_symbolic_function(g, block, n, inputs, env, operator_export_type=Operat
|
|||||||
|
|
||||||
# Generate an ONNX ATen op node.
|
# Generate an ONNX ATen op node.
|
||||||
def _aten_op(g, operator, *args, overload_name="", **kwargs):
|
def _aten_op(g, operator, *args, overload_name="", **kwargs):
|
||||||
|
kwargs["aten"] = True
|
||||||
return g.op("ATen", *args, operator_s=operator, overload_name_s=overload_name, **kwargs)
|
return g.op("ATen", *args, operator_s=operator, overload_name_s=overload_name, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@ -1315,7 +1312,7 @@ def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names):
|
|||||||
|
|
||||||
|
|
||||||
torch._C.Graph.op = _graph_op # type: ignore[attr-defined]
|
torch._C.Graph.op = _graph_op # type: ignore[attr-defined]
|
||||||
torch._C.Graph.at = _aten_op # type: ignore[attr-defined]
|
torch._C.Graph.at = _aten_op # type: ignore[attr-defined]
|
||||||
torch._C.Block.op = _block_op # type: ignore[attr-defined]
|
torch._C.Block.op = _block_op # type: ignore[attr-defined]
|
||||||
torch._C.Graph.constant = _graph_constant # type: ignore[attr-defined]
|
torch._C.Graph.constant = _graph_constant # type: ignore[attr-defined]
|
||||||
torch._C.Node.__getitem__ = _node_getitem # type: ignore[attr-defined, misc, assignment]
|
torch._C.Node.__getitem__ = _node_getitem # type: ignore[attr-defined, misc, assignment]
|
||||||
|
|||||||
@ -295,6 +295,14 @@ def skipIfNoQNNPACK(fn):
|
|||||||
fn(*args, **kwargs)
|
fn(*args, **kwargs)
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
@functools.wraps(fn)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
if not torch.onnx._CAFFE2_ATEN_FALLBACK:
|
||||||
|
raise unittest.SkipTest(reason)
|
||||||
|
else:
|
||||||
|
fn(*args, **kwargs)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import torchvision # noqa: F401
|
import torchvision # noqa: F401
|
||||||
HAS_TORCHVISION = True
|
HAS_TORCHVISION = True
|
||||||
|
|||||||
@ -1023,7 +1023,6 @@ def skipIfNoLapack(fn):
|
|||||||
fn(*args, **kwargs)
|
fn(*args, **kwargs)
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def skipIfNotRegistered(op_name, message):
|
def skipIfNotRegistered(op_name, message):
|
||||||
"""Wraps the decorator to hide the import of the `core`.
|
"""Wraps the decorator to hide the import of the `core`.
|
||||||
|
|
||||||
@ -1045,6 +1044,17 @@ def skipIfNotRegistered(op_name, message):
|
|||||||
skipper = unittest.skip("Cannot import `caffe2.python.core`")
|
skipper = unittest.skip("Cannot import `caffe2.python.core`")
|
||||||
return skipper
|
return skipper
|
||||||
|
|
||||||
|
def _decide_skip_caffe2(expect_caffe2, reason):
|
||||||
|
def skip_dec(func):
|
||||||
|
def wrapper(self):
|
||||||
|
if torch.onnx._CAFFE2_ATEN_FALLBACK != expect_caffe2:
|
||||||
|
raise unittest.SkipTest(reason)
|
||||||
|
return func(self)
|
||||||
|
return wrapper
|
||||||
|
return skip_dec
|
||||||
|
|
||||||
|
skipIfCaffe2 = _decide_skip_caffe2(False, "Not compatible with Caffe2")
|
||||||
|
skipIfNoCaffe2 = _decide_skip_caffe2(True, "Caffe2 is not available")
|
||||||
|
|
||||||
def skipIfNoSciPy(fn):
|
def skipIfNoSciPy(fn):
|
||||||
@wraps(fn)
|
@wraps(fn)
|
||||||
|
|||||||
Reference in New Issue
Block a user