Fix ONNX ATen fallback for non-caffe2 engines

This PR introduces 3 BC changes:

First, this PR propagates `BUILD_CAFFE2` flag to `libtorch` and `libtorch_python`, which is necessary for non-caffe2 ONNX runtimes when using `ONNX_ATEN_FALLBACK` operator export type.

Second, as a complement of https://github.com/pytorch/pytorch/pull/68490, this PR refactors Caffe2's Aten ops symbolics to consider not only the `operator_export_type` (aka `ONNX_ATEN_FALLBACK`) to emit Caffe2 Aten ops, but also whether `BUILD_CAFFE2` (which is called `torch.onnx._CAFFE2_ATEN_FALLBACK` in python binding) is set.

Lastly, it renames `onnx::ATen` to `aten::ATen` for ONNX spec consistency in a BC fashion.
ONNX doesn't have `ATen` op on its spec, but PyTorch ONNX converter emits them. Non-Caffe2 backend engines would be mislead by such operator's name/domain. A non-ideal workaround would be to have Aten ops handled based on its name and ignore the (non-complaint) domain. Moreover, users could incorrectly file bugs to either ONNX or ONNX Runtime when they inspect the model and notice the presence of an unspecified ONNX operator.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73954
Approved by: https://github.com/BowenBao, https://github.com/malfet, https://github.com/garymm, https://github.com/jiafatom
This commit is contained in:
Thiago Crepaldi
2022-04-14 23:18:45 +00:00
committed by PyTorch MergeBot
parent b142a224c6
commit 9bbe1d632e
26 changed files with 146 additions and 112 deletions

View File

@ -247,7 +247,7 @@ namespace c10 {
_(onnx, Less) \
_(onnx, LessOrEqual) \
_(onnx, Not) \
_(onnx, ATen) \
_(aten, ATen) \
_(onnx, Split) \
_(onnx, ConstantOfShape) \
_(onnx, Cast) \
@ -316,7 +316,8 @@ namespace c10 {
_(attr, new_axis) \
_(attr, warn_id) \
_(attr, allowzero) \
_(attr, seen_none)
_(attr, seen_none) \
_(attr, overload_name)
enum class _keys : unique_t {
#define DEFINE_KEY(ns, s) ns##_##s,

View File

@ -59,7 +59,7 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Utilitity to generate Caffe2 benchmark models.")
description="Utility to generate Caffe2 benchmark models.")
parser.add_argument("operator", help="Caffe2 operator to benchmark.")
parser.add_argument("-b", "--blob",
help="Instantiate a blob --blob name=dim1,dim2,dim3",

View File

@ -1943,6 +1943,8 @@ if(BUILD_PYTHON)
# ---[ Python.
if(BUILD_CAFFE2)
add_library(caffe2_pybind11_state MODULE ${Caffe2_CPU_PYTHON_SRCS})
target_compile_definitions(torch PRIVATE BUILD_CAFFE2)
target_compile_definitions(torch_python PRIVATE BUILD_CAFFE2)
if(USE_NUMPY)
target_compile_options(caffe2_pybind11_state PRIVATE "-DUSE_NUMPY")
target_link_libraries(caffe2_pybind11_state PRIVATE numpy::numpy)

View File

@ -72,7 +72,7 @@ class Add(torch.autograd.Function):
@staticmethod
def symbolic(g, a, b):
return g.op("ATen", a, b, operator_s = "add")
return g.at("add", a, b)
@staticmethod
def forward(ctx, a, b):

View File

@ -179,7 +179,7 @@ private:
std::vector<std::string> attrs;
for (const auto i : c10::irange(operator_def.arg_size())) {
auto & attr = operator_def.arg(i);
if(attr.name() == "operator" || attr.name() == "type" || attr.name() == "overload_name" ) {
if (attr.name() == "operator" || attr.name() == "type" || attr.name() == "overload_name") {
continue;
}
attrs.push_back(attr.name());

View File

@ -1,9 +1,4 @@
from caffe2.python import core, dyndep
from caffe2.python import core
from hypothesis import given
import caffe2.python.hypothesis_test_util as hu

View File

@ -38,8 +38,8 @@ torch.onnx.export(MyModule(),
# graph(%input : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
# %y : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
# %2 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = onnx::Relu(%input)
# %3 : Tensor = onnx::ATen[operator="mul"](%2, %2)
# %4 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = onnx::ATen[operator="add"](%3, %y)
# %3 : Tensor = aten::ATen[operator="mul"](%2, %2)
# %4 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::ATen[operator="add"](%3, %y)
# return (%4)
graph = onnx.load(f.name)

View File

@ -106,7 +106,7 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Utilitity to generate Caffe2 benchmark models.")
description="Utility to generate Caffe2 benchmark models.")
parser.add_argument("operator", help="Caffe2 operator to benchmark.")
parser.add_argument("-b", "--blob",
help="Instantiate a blob --blob name=dim1,dim2,dim3",

View File

@ -11,7 +11,7 @@ ModelProto {
nodes: [
Node {type: "Add", inputs: [0,1], outputs: [2], attributes: []},
Node {type: "Constant", inputs: [], outputs: [3], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "ATen", inputs: [2,3], outputs: [4,5], attributes: [{ name: 'operator', type: string, value: 'qr'}, { name: 'overload_name', type: string, value: ''}]}
Node {type: "ATen", domain: "org.pytorch.aten", inputs: [2,3], outputs: [4,5], attributes: [{ name: 'operator', type: string, value: 'qr'}, { name: 'overload_name', type: string, value: ''}]}
]
}
opset_import: [OperatorSetIdProto { domain: }OperatorSetIdProto { domain: org.pytorch.aten}],

View File

@ -9,7 +9,7 @@ ModelProto {
outputs: [{name: "2", type:Tensor dims: 3 4}]
initializers: []
nodes: [
Node {type: "ATen", inputs: [0,1], outputs: [2], attributes: [{ name: 'operator', type: string, value: 'fmod'}, { name: 'overload_name', type: string, value: ''}]}
Node {type: "ATen", domain: "org.pytorch.aten", inputs: [0,1], outputs: [2], attributes: [{ name: 'operator', type: string, value: 'fmod'}, { name: 'overload_name', type: string, value: ''}]}
]
}
opset_import: [OperatorSetIdProto { domain: }OperatorSetIdProto { domain: org.pytorch.aten}],

View File

@ -13,7 +13,7 @@ ModelProto {
Node {type: "Less", inputs: [0,1], outputs: [2], attributes: []},
Node {type: "Cast", inputs: [2], outputs: [3], attributes: [{ name: 'to', type: int, value: 2}]},
Node {type: "Cast", inputs: [3], outputs: [4], attributes: [{ name: 'to', type: int, value: 9}]},
Node {type: "ATen", inputs: [0,4], outputs: [5], attributes: [{ name: 'operator', type: string, value: 'index'}, { name: 'overload_name', type: string, value: ''}]}
Node {type: "ATen", domain: "org.pytorch.aten", inputs: [0,4], outputs: [5], attributes: [{ name: 'operator', type: string, value: 'index'}, { name: 'overload_name', type: string, value: ''}]}
]
}
opset_import: [OperatorSetIdProto { domain: }OperatorSetIdProto { domain: org.pytorch.aten}],

View File

@ -18,6 +18,7 @@ graph {
s: ""
type: STRING
}
domain: "org.pytorch.aten"
}
name: "torch_jit"
input {
@ -56,3 +57,7 @@ graph {
opset_import {
version: 13
}
opset_import {
domain: "org.pytorch.aten"
version: 1
}

View File

@ -6,19 +6,24 @@ graph {
input: "emb.weight"
input: "input_1"
output: "onnx::Add_3"
name: "ATenOp_0"
op_type: "ATenOp"
name: "ATen_0"
op_type: "ATen"
attribute {
name: "custom_attributes_json"
s: "{\"padding_idx\":-1,\"scale_grad_by_freq\":false,\"sparse\":false}"
type: STRING
}
attribute {
name: "name"
s: "aten::embedding"
name: "operator"
s: "embedding"
type: STRING
}
domain: "com.microsoft"
attribute {
name: "overload_name"
s: ""
type: STRING
}
domain: "org.pytorch.aten"
}
node {
input: "onnx::Add_3"
@ -145,27 +150,11 @@ graph {
}
}
}
value_info {
name: "onnx::Add_3"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_param: "ATenOponnx::Add_3_dim_0"
}
dim {
dim_param: "ATenOponnx::Add_3_dim_1"
}
}
}
}
}
}
opset_import {
version: 12
}
opset_import {
domain: "com.microsoft"
domain: "org.pytorch.aten"
version: 1
}

View File

@ -2,30 +2,9 @@ ir_version: 7
producer_name: "pytorch"
producer_version: "CURRENT_VERSION"
graph {
node {
output: "onnx::Cast_3"
op_type: "Constant"
attribute {
name: "value"
t {
data_type: 7
raw_data: "\001\000\000\000\000\000\000\000"
}
type: TENSOR
}
}
node {
input: "onnx::Cast_3"
output: "onnx::Loop_4"
op_type: "Cast"
attribute {
name: "to"
i: 9
type: INT
}
}
node {
output: "5"
name: "Constant_0"
op_type: "Constant"
attribute {
name: "value"
@ -40,10 +19,12 @@ graph {
node {
input: "input"
output: "onnx::Gather_6"
name: "Shape_1"
op_type: "Shape"
}
node {
output: "onnx::Gather_7"
name: "Constant_2"
op_type: "Constant"
attribute {
name: "value"
@ -58,6 +39,7 @@ graph {
input: "onnx::Gather_6"
input: "onnx::Gather_7"
output: "onnx::Unsqueeze_8"
name: "Gather_3"
op_type: "Gather"
attribute {
name: "axis"
@ -67,6 +49,7 @@ graph {
}
node {
output: "onnx::Unsqueeze_9"
name: "Constant_4"
op_type: "Constant"
attribute {
name: "value"
@ -82,12 +65,14 @@ graph {
input: "onnx::Unsqueeze_8"
input: "onnx::Unsqueeze_9"
output: "onnx::Concat_10"
name: "Unsqueeze_5"
op_type: "Unsqueeze"
}
node {
input: "offsets"
input: "onnx::Concat_10"
output: "onnx::Slice_11"
name: "Concat_6"
op_type: "Concat"
attribute {
name: "axis"
@ -97,6 +82,7 @@ graph {
}
node {
output: "onnx::Slice_12"
name: "Constant_7"
op_type: "Constant"
attribute {
name: "value"
@ -110,6 +96,7 @@ graph {
}
node {
output: "onnx::Slice_13"
name: "Constant_8"
op_type: "Constant"
attribute {
name: "value"
@ -123,6 +110,7 @@ graph {
}
node {
output: "onnx::Slice_14"
name: "Constant_9"
op_type: "Constant"
attribute {
name: "value"
@ -136,6 +124,7 @@ graph {
}
node {
output: "onnx::Slice_15"
name: "Constant_10"
op_type: "Constant"
attribute {
name: "value"
@ -154,15 +143,18 @@ graph {
input: "onnx::Slice_12"
input: "onnx::Slice_15"
output: "onnx::Shape_16"
name: "Slice_11"
op_type: "Slice"
}
node {
input: "onnx::Shape_16"
output: "onnx::Gather_17"
name: "Shape_12"
op_type: "Shape"
}
node {
output: "onnx::Gather_18"
name: "Constant_13"
op_type: "Constant"
attribute {
name: "value"
@ -177,6 +169,7 @@ graph {
input: "onnx::Gather_17"
input: "onnx::Gather_18"
output: "onnx::Loop_19"
name: "Gather_14"
op_type: "Gather"
attribute {
name: "axis"
@ -186,8 +179,9 @@ graph {
}
node {
input: "onnx::Loop_19"
input: "onnx::Loop_4"
input: "onnx::Loop_33"
output: "20"
name: "Loop_15"
op_type: "Loop"
attribute {
name: "body"
@ -196,7 +190,7 @@ graph {
input: "onnx::Slice_11"
input: "21"
output: "23"
name: "Gather_0"
name: "Gather_16"
op_type: "Gather"
attribute {
name: "axis"
@ -208,7 +202,7 @@ graph {
input: "onnx::Shape_16"
input: "21"
output: "24"
name: "Gather_1"
name: "Gather_17"
op_type: "Gather"
attribute {
name: "axis"
@ -218,7 +212,7 @@ graph {
}
node {
output: "25"
name: "Constant_2"
name: "Constant_18"
op_type: "Constant"
attribute {
name: "value"
@ -234,12 +228,12 @@ graph {
input: "23"
input: "25"
output: "26"
name: "Unsqueeze_3"
name: "Unsqueeze_19"
op_type: "Unsqueeze"
}
node {
output: "27"
name: "Constant_4"
name: "Constant_20"
op_type: "Constant"
attribute {
name: "value"
@ -255,7 +249,7 @@ graph {
input: "24"
input: "27"
output: "28"
name: "Unsqueeze_5"
name: "Unsqueeze_21"
op_type: "Unsqueeze"
}
node {
@ -264,14 +258,14 @@ graph {
input: "28"
input: "5"
output: "29"
name: "Slice_6"
name: "Slice_22"
op_type: "Slice"
}
node {
input: "weight"
input: "29"
output: "30"
name: "Gather_7"
name: "Gather_23"
op_type: "Gather"
attribute {
name: "axis"
@ -282,7 +276,7 @@ graph {
node {
input: "30"
output: "31"
name: "ReduceMean_8"
name: "ReduceMean_24"
op_type: "ReduceMean"
attribute {
name: "axes"
@ -296,9 +290,9 @@ graph {
}
}
node {
input: "onnx::Loop_4"
input: "onnx::Loop_33"
output: "32"
name: "Cast_9"
name: "Cast_25"
op_type: "Cast"
attribute {
name: "to"
@ -362,6 +356,11 @@ graph {
name: "weight"
raw_data: "\264\314\344\275\017A\376\276\313\374&>J\266a\277s\306\\=\212\032+?\211[t\275\344[\357\276Dk\\\276OKb?\234\'B\277A\334\274\2767N\257\276\320s\263\277\371+\244>:\314\202\277K\200L??\001\275\275\236u4\2774\032\315\277\214\004\224>Z\320\372>\267B\305\276\346G6\277N\265.\276\343\316\272\277t\364a>\201)|>p\223\251\277Qm2?\346\275)\277\354\235\233?\027X\277\277\253\206a?\354\335\226\277L\032o\277\251J\021\277\311\360\215\276\312\274\013\300\252\320\273>\220\"p?\267\020\000<R\262\240\276\343\016\224\2779\241\353?8;\202\277\023\020\234?E\370#>\222\233\314?\334\360?\275|t\303\277\214\351\000\300\3065\302\2775\206\306>X\251\227\277x\2160?U^\251?d\221\350?\237F.?\rp9?9X\004=/c\324\277SL\360\277\'\274<?t\375l?\342\270l?\240\352:>\332\356\226\275\211\035\241>*\271\204\277>\025W>\036K\035?\036\233\200=\035\313\250\276\017\003\346\277\374p_?\313WD?!\006\351\275\232\\q\277\230\007A?"
}
initializer {
data_type: 9
name: "onnx::Loop_33"
raw_data: "\001"
}
input {
name: "input"
type {
@ -404,6 +403,16 @@ graph {
}
}
}
input {
name: "onnx::Loop_33"
type {
tensor_type {
elem_type: 9
shape {
}
}
}
}
output {
name: "20"
type {

View File

@ -7,6 +7,7 @@ graph {
input: "weight"
input: "bias"
output: "3"
name: "ATen_0"
op_type: "ATen"
attribute {
name: "cudnn_enable"
@ -34,6 +35,7 @@ graph {
s: ""
type: STRING
}
domain: "org.pytorch.aten"
}
name: "torch_jit"
initializer {
@ -130,3 +132,7 @@ graph {
opset_import {
version: 13
}
opset_import {
domain: "org.pytorch.aten"
version: 1
}

View File

@ -20,12 +20,15 @@ import os
import shutil
import tempfile
import torch.testing._internal.common_utils as common
from torch.testing._internal.common_utils import skipIfCaffe2
'''Usage: python test/onnx/test_operators.py [--no-onnx] [--produce-onnx-test-data]
--no-onnx: no onnx python dependence
--produce-onnx-test-data: generate onnx test data
--accept: accept onnx updates and overwrite models
'''
import unittest
unittest.TestCase.maxDiff = None
_onnx_test = False # flag to produce onnx test cases.
_onnx_dep = True # flag to import onnx package.
@ -322,6 +325,7 @@ class TestOperators(TestCase):
x = torch.randn(20, 16, 50)
self.assertONNX(nn.MaxPool1d(3, stride=2, return_indices=True), x)
@skipIfCaffe2
def test_at_op(self):
x = torch.randn(3, 4)
@ -339,7 +343,8 @@ class TestOperators(TestCase):
def forward(self, x):
return MyFun.apply(x)
self.assertONNX(MyModule(), x)
self.assertONNX(MyModule(), x,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
def test_clip(self):
x = torch.randn(3, 4, requires_grad=True)
@ -588,6 +593,7 @@ class TestOperators(TestCase):
self.assertONNX(nn.BatchNorm2d(128, affine=False, momentum=0.3), x,
keep_initializers_as_inputs=True)
@skipIfCaffe2
def test_embedding_bags(self):
emb_bag = nn.EmbeddingBag(10, 8)
input = torch.tensor([1, 2, 3, 4]).long()
@ -787,6 +793,7 @@ class TestOperators(TestCase):
input2 = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2)
self.assertONNX(BitshiftModel(), (input, input2), opset_version=11)
@skipIfCaffe2
def test_layer_norm_aten(self):
model = torch.nn.LayerNorm([10, 10])
x = torch.randn(20, 5, 10, 10)
@ -954,7 +961,7 @@ class TestOperators(TestCase):
f'"sparse":{str(sparse).lower()}'
'}'
)
output = g.op("com.microsoft::ATenOp", weight, indices, name_s='aten::embedding',
output = g.at("embedding", weight, indices,
custom_attributes_json_s=custom_attributes_json)
return output
@ -978,6 +985,7 @@ class TestOperators(TestCase):
unregister_custom_op_symbolic('::embedding', _onnx_opset_version)
# This is test_aten_embedding_1 with shape inference on custom symbolic aten::embedding.
@skipIfCaffe2
def test_aten_embedding_2(self):
_onnx_opset_version = 12
@ -990,7 +998,7 @@ class TestOperators(TestCase):
f'"sparse":{str(sparse).lower()}'
'}'
)
output = g.op("com.microsoft::ATenOp", weight, indices, name_s='aten::embedding',
output = g.at("embedding", weight, indices,
custom_attributes_json_s=custom_attributes_json)
# do shape inference and set it via setType
@ -1016,7 +1024,9 @@ class TestOperators(TestCase):
x = torch.ones(32, dtype=torch.long)
y = torch.randn(1, 8)
self.assertONNX(model, (x, y), opset_version=_onnx_opset_version, input_names=['input_1', 'input_2'],
dynamic_axes={"input_1": {0: "dim_0"}, 'input_2': {0: "dim_1", 1: "dim_2"}})
dynamic_axes={"input_1": {0: "dim_0"}, 'input_2': {0: "dim_1", 1: "dim_2"}},
keep_initializers_as_inputs=False,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
unregister_custom_op_symbolic('::embedding', _onnx_opset_version)

View File

@ -62,6 +62,8 @@ from torch.testing._internal.common_quantized import (
override_qengines,
)
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_utils import skipIfNoCaffe2
from hypothesis import given
from hypothesis import strategies as st
import torch.testing._internal.hypothesis_utils as hu
@ -1464,6 +1466,7 @@ class TestQuantizeEagerONNXExport(JitTestCase):
onnx_model = export_to_onnx(model, data, input_names)
@skipIfNoFBGEMM
@skipIfNoCaffe2
def test_lower_graph_linear(self):
model = torch.ao.quantization.QuantWrapper(torch.nn.Linear(5, 10, bias=True)).to(dtype=torch.float)
data_numpy = np.random.rand(1, 2, 5).astype(np.float32)
@ -1471,6 +1474,7 @@ class TestQuantizeEagerONNXExport(JitTestCase):
self._test_lower_graph_impl(model, data)
@skipIfNoFBGEMM
@skipIfNoCaffe2
def test_lower_graph_conv2d(self):
model = torch.ao.quantization.QuantWrapper(torch.nn.Conv2d(3, 5, 2, bias=True)).to(dtype=torch.float)
data_numpy = np.random.rand(1, 3, 6, 6).astype(np.float32)

View File

@ -362,7 +362,6 @@ if(USE_NUMPY)
target_compile_definitions(torch_python PRIVATE USE_NUMPY)
endif()
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS BUILD_CAFFE2)
if(HAVE_SOVERSION)
set_target_properties(torch_python PROPERTIES
VERSION ${TORCH_VERSION} SOVERSION ${TORCH_SOVERSION})

View File

@ -1905,7 +1905,8 @@ static std::unordered_set<std::string> nodeTypeReliableForTracer = {
"onnx::Cast",
"onnx::Constant",
"onnx::Relu",
"com.microsoft::Gelu"};
"com.microsoft::Gelu",
"aten::ATen"};
void UpdateReliable(
torch::jit::Value* output,

View File

@ -113,7 +113,7 @@ void validateBlock(
WithInsertPoint guard(node);
auto* new_node =
b->owningGraph()->insertNode(b->owningGraph()->create(
Symbol(::c10::onnx::ATen),
Symbol(::c10::aten::ATen),
node->inputs(),
node->outputs().size()));
for (size_t i = 0; i < node->outputs().size(); ++i) {
@ -1163,8 +1163,8 @@ void GraphEncoder::EncodeIntermediateValueInfo(
const Value* v) {
// Motivation is to encode ValueInfo for onnx local function nodes.
auto n = v->node();
if (n->kind().is_onnx()) {
// Encode value info only for non-onnx nodes.
if (n->kind().is_onnx() || n->kind().is_aten()) {
// Encode value info only for non-onnx or non-ATen nodes.
return;
}
if (n->owningGraph() != graph_.get()) {

View File

@ -329,6 +329,10 @@ def _is_scalar_list(x):
element_type in scalar_name_to_pytorch.keys() and \
(scalar_name_to_pytorch[element_type] in cast_pytorch_to_onnx.keys())
def is_caffe2_aten_fallback():
return (_operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK and
torch.onnx._CAFFE2_ATEN_FALLBACK)
def _get_tensor_rank(x):
if not _is_tensor(x) or x.type() is None:
return None

View File

@ -551,10 +551,7 @@ def arange(g, *args):
def _dim_arange(g, like, dim):
like_shape = g.op("Shape", like)
stop = g.op("Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0)
# Caffe2-specific op
is_caffe2_aten_fallback = (sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK and
torch.onnx._CAFFE2_ATEN_FALLBACK)
if is_caffe2_aten_fallback:
if sym_help.is_caffe2_aten_fallback():
return g.op("_caffe2::Range", stop)
return arange(g, stop, 4, None, None, None)
@ -643,7 +640,8 @@ def index(g, self, index):
def index_fill(g, self, dim, index, value):
dim_value = sym_help._parse_arg(dim, "i")
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
return g.at("index_fill", self, index, value, dim_i=dim_value, overload_name="int_Scalar")
return g.at("index_fill", self, index, value, overload_name="int_Scalar", dim_i=dim_value)
expanded_index_shape, expanded_index = sym_help._index_fill_reshape_helper(g, self, dim, index)
value = sym_help._maybe_get_scalar(value)
value = sym_help._if_scalar_type_as(g, value, self)

View File

@ -565,7 +565,7 @@ def transpose(g, self, dim0, dim1):
# if we don't have dim information we cannot
# output a permute so use ATen instead
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
return g.at("transpose", self, dim0_i=dim0, dim1_i=dim1, overload_name="int")
return g.at("transpose", self, overload_name="int", dim0_i=dim0, dim1_i=dim1)
else:
raise RuntimeError("Unsupported: ONNX export of transpose for tensor "
"of unknown rank.")
@ -1581,7 +1581,8 @@ def index_put(g, self, indices_list_value, values, accumulate):
def index_fill(g, self, dim, index, value):
dim_value = sym_help._parse_arg(dim, "i")
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
return g.at("index_fill", self, index, value, dim_i=dim_value, overload_name="int_Scalar")
return g.at("index_fill", self, index, value, overload_name="int_Scalar", dim_i=dim_value)
expanded_index_shape, expanded_index = sym_help._index_fill_reshape_helper(g, self, dim, index)
value = sym_help._maybe_get_scalar(value)
value = sym_help._if_scalar_type_as(g, value, self)
@ -1647,9 +1648,7 @@ def type_as(g, self, other):
@parse_args("v", "v", "i", "f")
def cosine_similarity(g, x1, x2, dim, eps):
# preserve legacy behavior for Caffe2
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK and \
torch.onnx._CAFFE2_ATEN_FALLBACK:
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
return g.at("cosine_similarity", x1, x2, dim_i=dim, eps_f=eps)
cross = sym_help._reducesum_helper(g, mul(g, x1, x2),
axes_i=[dim], keepdims_i=0)
@ -2599,10 +2598,7 @@ rnn_relu = _one_hidden_rnn("RNN_RELU")
def _dim_arange(g, like, dim):
like_shape = g.op("Shape", like)
stop = g.op("Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0)
# Caffe2-specific op
is_caffe2_aten_fallback = (sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK and
torch.onnx._CAFFE2_ATEN_FALLBACK)
if is_caffe2_aten_fallback:
if sym_help.is_caffe2_aten_fallback():
return g.op("_caffe2::Range", stop)
else:
# aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)

View File

@ -190,7 +190,7 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa
torch._C._jit_pass_peephole(graph, True)
torch._C._jit_pass_fuse_addmm(graph)
torch._C._jit_pass_lint(graph)
from torch.onnx.symbolic_helper import _onnx_shape_inference, _export_onnx_opset_version
from torch.onnx.symbolic_helper import _onnx_shape_inference, _export_onnx_opset_version, is_caffe2_aten_fallback
torch._C._jit_pass_peephole(graph, True)
torch._C._jit_pass_lower_all_tuples(graph)
@ -212,13 +212,10 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa
torch._C._jit_pass_onnx_remove_print(graph)
torch._C._jit_pass_onnx_preprocess_caffe2(graph)
# Caffe2-specific optimization
is_caffe2_aten_fallback = (operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK and
torch.onnx._CAFFE2_ATEN_FALLBACK)
torch.onnx.symbolic_helper._quantized_ops.clear()
# Unpack quantized weights for conv and linear ops and insert into graph.
torch._C._jit_pass_onnx_unpack_quantized_weights(graph, params_dict, is_caffe2_aten_fallback)
if is_caffe2_aten_fallback:
torch._C._jit_pass_onnx_unpack_quantized_weights(graph, params_dict, is_caffe2_aten_fallback())
if is_caffe2_aten_fallback():
# Insert permutes before and after each conv op to ensure correct order.
torch._C._jit_pass_onnx_quantization_insert_permutes(graph, params_dict)
@ -289,7 +286,7 @@ def warn_on_static_input_change(input_states):
def _resolve_args_by_export_type(arg_name, arg_value, operator_export_type):
# This helper method resolves the arguments that are ignored when export_type != operator_export_type.ONNX
if operator_export_type is not operator_export_type.ONNX:
if operator_export_type is not operator_export_type.ONNX and torch.onnx._CAFFE2_ATEN_FALLBACK:
if arg_value is True:
warnings.warn("`{}' can be set to True only when 'operator_export_type' is "
"`ONNX`. Since 'operator_export_type' is not set to 'ONNX', "
@ -959,7 +956,8 @@ def _add_attribute(node, key, value, aten):
name, kind = m.group(1), m.group(2)
if _is_onnx_list(value):
kind += "s"
if aten:
from torch.onnx.symbolic_helper import is_caffe2_aten_fallback
if aten and is_caffe2_aten_fallback():
if isinstance(value, torch.Tensor):
# Caffe2 proto does not support tensor attribute.
if value.numel() > 1:
@ -1119,14 +1117,13 @@ def _run_symbolic_function(g, block, n, inputs, env, operator_export_type=Operat
try:
import torch
from torch.onnx.symbolic_helper import _export_onnx_opset_version as opset_version
from torch.onnx.symbolic_helper import is_caffe2_aten_fallback
import torch.onnx.symbolic_registry as sym_registry
sym_registry.register_version("", opset_version)
# Caffe2-specific: Quantized op symbolics are registered for opset 9 only.
is_caffe2_aten_fallback = (operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK and
torch.onnx._CAFFE2_ATEN_FALLBACK)
if is_caffe2_aten_fallback and opset_version == 9:
if is_caffe2_aten_fallback() and opset_version == 9:
import torch.onnx.symbolic_caffe2
torch.onnx.symbolic_caffe2.register_quantized_ops("caffe2", opset_version)
@ -1137,11 +1134,10 @@ def _run_symbolic_function(g, block, n, inputs, env, operator_export_type=Operat
else:
ns_op_name = n.kind()
ns, op_name = ns_op_name.split("::")
domain = ns
if ns == "aten":
domain = ""
elif ns == "quantized" and is_caffe2_aten_fallback:
elif ns == "quantized" and is_caffe2_aten_fallback():
domain = "caffe2"
if sym_registry.is_registered_op(op_name, domain, opset_version):
@ -1165,7 +1161,7 @@ def _run_symbolic_function(g, block, n, inputs, env, operator_export_type=Operat
attrs = {k + "_" + n.kindOf(k)[0]: n[k] for k in n.attributeNames()}
outputs = n.outputsSize()
attrs["outputs"] = outputs
return g.at(op_name, *inputs, aten=True, **attrs)
return g.at(op_name, *inputs, **attrs)
else:
raise sym_registry.UnsupportedOperatorError(domain, op_name, opset_version)
except RuntimeError:
@ -1181,6 +1177,7 @@ def _run_symbolic_function(g, block, n, inputs, env, operator_export_type=Operat
# Generate an ONNX ATen op node.
def _aten_op(g, operator, *args, overload_name="", **kwargs):
kwargs["aten"] = True
return g.op("ATen", *args, operator_s=operator, overload_name_s=overload_name, **kwargs)
@ -1315,7 +1312,7 @@ def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names):
torch._C.Graph.op = _graph_op # type: ignore[attr-defined]
torch._C.Graph.at = _aten_op # type: ignore[attr-defined]
torch._C.Graph.at = _aten_op # type: ignore[attr-defined]
torch._C.Block.op = _block_op # type: ignore[attr-defined]
torch._C.Graph.constant = _graph_constant # type: ignore[attr-defined]
torch._C.Node.__getitem__ = _node_getitem # type: ignore[attr-defined, misc, assignment]

View File

@ -295,6 +295,14 @@ def skipIfNoQNNPACK(fn):
fn(*args, **kwargs)
return wrapper
@functools.wraps(fn)
def wrapper(*args, **kwargs):
if not torch.onnx._CAFFE2_ATEN_FALLBACK:
raise unittest.SkipTest(reason)
else:
fn(*args, **kwargs)
return wrapper
try:
import torchvision # noqa: F401
HAS_TORCHVISION = True

View File

@ -1023,7 +1023,6 @@ def skipIfNoLapack(fn):
fn(*args, **kwargs)
return wrapper
def skipIfNotRegistered(op_name, message):
"""Wraps the decorator to hide the import of the `core`.
@ -1045,6 +1044,17 @@ def skipIfNotRegistered(op_name, message):
skipper = unittest.skip("Cannot import `caffe2.python.core`")
return skipper
def _decide_skip_caffe2(expect_caffe2, reason):
def skip_dec(func):
def wrapper(self):
if torch.onnx._CAFFE2_ATEN_FALLBACK != expect_caffe2:
raise unittest.SkipTest(reason)
return func(self)
return wrapper
return skip_dec
skipIfCaffe2 = _decide_skip_caffe2(False, "Not compatible with Caffe2")
skipIfNoCaffe2 = _decide_skip_caffe2(True, "Caffe2 is not available")
def skipIfNoSciPy(fn):
@wraps(fn)