Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
Fix ONNX ATen fallback for non-caffe2 engines
This PR introduces three backward-compatibility (BC) changes. First, it propagates the `BUILD_CAFFE2` flag to `libtorch` and `libtorch_python`, which is necessary for non-Caffe2 ONNX runtimes when the `ONNX_ATEN_FALLBACK` operator export type is used. Second, complementing https://github.com/pytorch/pytorch/pull/68490, it refactors the Caffe2 ATen-op symbolics so that Caffe2 ATen ops are emitted only when both `operator_export_type` is `ONNX_ATEN_FALLBACK` and `BUILD_CAFFE2` (exposed to Python as `torch.onnx._CAFFE2_ATEN_FALLBACK`) is set. Lastly, it renames `onnx::ATen` to `aten::ATen` for ONNX spec consistency, in a backward-compatible fashion. The ONNX spec does not define an `ATen` operator, yet the PyTorch ONNX converter emits such nodes; non-Caffe2 backend engines would be misled by the operator's name and domain, and a non-ideal workaround would be to handle ATen ops by name alone while ignoring the (non-compliant) domain. Moreover, users could incorrectly file bugs against ONNX or ONNX Runtime after inspecting a model and noticing an operator that is not part of the ONNX spec.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73954
Approved by: https://github.com/BowenBao, https://github.com/malfet, https://github.com/garymm, https://github.com/jiafatom
Committed by: PyTorch MergeBot
Parent: b142a224c6
Commit: 9bbe1d632e
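As a rough illustration of the renamed fallback node (a sketch for this description, not code from the PR: the `FmodModule` module, the tensor shapes, and the choice of `torch.fmod` are illustrative assumptions based on the `fmod` expect-file diff below), exporting with `ONNX_ATEN_FALLBACK` on a build without Caffe2 should now produce an `ATen` node in the `org.pytorch.aten` domain instead of a node in the default `onnx` domain:

```python
import io

import onnx
import torch


class FmodModule(torch.nn.Module):
    # torch.fmod is one of the ops exercised through the ATen fallback path
    # in the expect files touched by this PR.
    def forward(self, x, y):
        return torch.fmod(x, y)


f = io.BytesIO()
torch.onnx.export(
    FmodModule(),
    (torch.randn(3, 4), torch.randn(3, 4)),
    f,
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
)

model = onnx.load_model_from_string(f.getvalue())
for node in model.graph.node:
    # On a BUILD_CAFFE2=0 build (torch.onnx._CAFFE2_ATEN_FALLBACK is False),
    # a fallback node should show up as op_type "ATen" with domain
    # "org.pytorch.aten" and carry `operator`/`overload_name` attributes.
    print(node.op_type, node.domain)
```

Backends that understand the `org.pytorch.aten` domain can dispatch on the `operator` and `overload_name` attributes, while Caffe2 builds (`BUILD_CAFFE2=1`) keep the previous Caffe2-specific behavior.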
@@ -247,7 +247,7 @@ namespace c10 {
 _(onnx, Less) \
 _(onnx, LessOrEqual) \
 _(onnx, Not) \
-_(onnx, ATen) \
+_(aten, ATen) \
 _(onnx, Split) \
 _(onnx, ConstantOfShape) \
 _(onnx, Cast) \
@@ -316,7 +316,8 @@ namespace c10 {
 _(attr, new_axis) \
 _(attr, warn_id) \
 _(attr, allowzero) \
-_(attr, seen_none)
+_(attr, seen_none) \
+_(attr, overload_name)

 enum class _keys : unique_t {
 #define DEFINE_KEY(ns, s) ns##_##s,
@@ -59,7 +59,7 @@ def main(args):

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="Utilitity to generate Caffe2 benchmark models.")
+        description="Utility to generate Caffe2 benchmark models.")
     parser.add_argument("operator", help="Caffe2 operator to benchmark.")
     parser.add_argument("-b", "--blob",
                         help="Instantiate a blob --blob name=dim1,dim2,dim3",
@@ -1943,6 +1943,8 @@ if(BUILD_PYTHON)
   # ---[ Python.
   if(BUILD_CAFFE2)
     add_library(caffe2_pybind11_state MODULE ${Caffe2_CPU_PYTHON_SRCS})
+    target_compile_definitions(torch PRIVATE BUILD_CAFFE2)
+    target_compile_definitions(torch_python PRIVATE BUILD_CAFFE2)
     if(USE_NUMPY)
       target_compile_options(caffe2_pybind11_state PRIVATE "-DUSE_NUMPY")
       target_link_libraries(caffe2_pybind11_state PRIVATE numpy::numpy)
@@ -72,7 +72,7 @@ class Add(torch.autograd.Function):

     @staticmethod
     def symbolic(g, a, b):
-        return g.op("ATen", a, b, operator_s = "add")
+        return g.at("add", a, b)

     @staticmethod
     def forward(ctx, a, b):
@@ -179,7 +179,7 @@ private:
   std::vector<std::string> attrs;
   for (const auto i : c10::irange(operator_def.arg_size())) {
     auto & attr = operator_def.arg(i);
-    if(attr.name() == "operator" || attr.name() == "type" || attr.name() == "overload_name" ) {
+    if (attr.name() == "operator" || attr.name() == "type" || attr.name() == "overload_name") {
       continue;
     }
     attrs.push_back(attr.name());
@@ -1,9 +1,4 @@
-from caffe2.python import core, dyndep
+from caffe2.python import core
 from hypothesis import given

 import caffe2.python.hypothesis_test_util as hu
@@ -38,8 +38,8 @@ torch.onnx.export(MyModule(),
 # graph(%input : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
 #       %y : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
 #   %2 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = onnx::Relu(%input)
-#   %3 : Tensor = onnx::ATen[operator="mul"](%2, %2)
-#   %4 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = onnx::ATen[operator="add"](%3, %y)
+#   %3 : Tensor = aten::ATen[operator="mul"](%2, %2)
+#   %4 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::ATen[operator="add"](%3, %y)
 #   return (%4)

 graph = onnx.load(f.name)
@@ -106,7 +106,7 @@ def main(args):

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="Utilitity to generate Caffe2 benchmark models.")
+        description="Utility to generate Caffe2 benchmark models.")
     parser.add_argument("operator", help="Caffe2 operator to benchmark.")
     parser.add_argument("-b", "--blob",
                         help="Instantiate a blob --blob name=dim1,dim2,dim3",
@@ -11,7 +11,7 @@ ModelProto {
 nodes: [
 Node {type: "Add", inputs: [0,1], outputs: [2], attributes: []},
 Node {type: "Constant", inputs: [], outputs: [3], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
-Node {type: "ATen", inputs: [2,3], outputs: [4,5], attributes: [{ name: 'operator', type: string, value: 'qr'}, { name: 'overload_name', type: string, value: ''}]}
+Node {type: "ATen", domain: "org.pytorch.aten", inputs: [2,3], outputs: [4,5], attributes: [{ name: 'operator', type: string, value: 'qr'}, { name: 'overload_name', type: string, value: ''}]}
 ]
 }
 opset_import: [OperatorSetIdProto { domain: }OperatorSetIdProto { domain: org.pytorch.aten}],
@@ -9,7 +9,7 @@ ModelProto {
 outputs: [{name: "2", type:Tensor dims: 3 4}]
 initializers: []
 nodes: [
-Node {type: "ATen", inputs: [0,1], outputs: [2], attributes: [{ name: 'operator', type: string, value: 'fmod'}, { name: 'overload_name', type: string, value: ''}]}
+Node {type: "ATen", domain: "org.pytorch.aten", inputs: [0,1], outputs: [2], attributes: [{ name: 'operator', type: string, value: 'fmod'}, { name: 'overload_name', type: string, value: ''}]}
 ]
 }
 opset_import: [OperatorSetIdProto { domain: }OperatorSetIdProto { domain: org.pytorch.aten}],
@@ -13,7 +13,7 @@ ModelProto {
 Node {type: "Less", inputs: [0,1], outputs: [2], attributes: []},
 Node {type: "Cast", inputs: [2], outputs: [3], attributes: [{ name: 'to', type: int, value: 2}]},
 Node {type: "Cast", inputs: [3], outputs: [4], attributes: [{ name: 'to', type: int, value: 9}]},
-Node {type: "ATen", inputs: [0,4], outputs: [5], attributes: [{ name: 'operator', type: string, value: 'index'}, { name: 'overload_name', type: string, value: ''}]}
+Node {type: "ATen", domain: "org.pytorch.aten", inputs: [0,4], outputs: [5], attributes: [{ name: 'operator', type: string, value: 'index'}, { name: 'overload_name', type: string, value: ''}]}
 ]
 }
 opset_import: [OperatorSetIdProto { domain: }OperatorSetIdProto { domain: org.pytorch.aten}],
@@ -18,6 +18,7 @@ graph {
       s: ""
       type: STRING
     }
+    domain: "org.pytorch.aten"
   }
   name: "torch_jit"
   input {
@@ -56,3 +57,7 @@ graph {
 opset_import {
   version: 13
 }
+opset_import {
+  domain: "org.pytorch.aten"
+  version: 1
+}
@ -6,19 +6,24 @@ graph {
|
||||
input: "emb.weight"
|
||||
input: "input_1"
|
||||
output: "onnx::Add_3"
|
||||
name: "ATenOp_0"
|
||||
op_type: "ATenOp"
|
||||
name: "ATen_0"
|
||||
op_type: "ATen"
|
||||
attribute {
|
||||
name: "custom_attributes_json"
|
||||
s: "{\"padding_idx\":-1,\"scale_grad_by_freq\":false,\"sparse\":false}"
|
||||
type: STRING
|
||||
}
|
||||
attribute {
|
||||
name: "name"
|
||||
s: "aten::embedding"
|
||||
name: "operator"
|
||||
s: "embedding"
|
||||
type: STRING
|
||||
}
|
||||
domain: "com.microsoft"
|
||||
attribute {
|
||||
name: "overload_name"
|
||||
s: ""
|
||||
type: STRING
|
||||
}
|
||||
domain: "org.pytorch.aten"
|
||||
}
|
||||
node {
|
||||
input: "onnx::Add_3"
|
||||
@ -145,27 +150,11 @@ graph {
|
||||
}
|
||||
}
|
||||
}
|
||||
value_info {
|
||||
name: "onnx::Add_3"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_param: "ATenOponnx::Add_3_dim_0"
|
||||
}
|
||||
dim {
|
||||
dim_param: "ATenOponnx::Add_3_dim_1"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 12
|
||||
}
|
||||
opset_import {
|
||||
domain: "com.microsoft"
|
||||
domain: "org.pytorch.aten"
|
||||
version: 1
|
||||
}
|
||||
|
@ -2,30 +2,9 @@ ir_version: 7
|
||||
producer_name: "pytorch"
|
||||
producer_version: "CURRENT_VERSION"
|
||||
graph {
|
||||
node {
|
||||
output: "onnx::Cast_3"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
data_type: 7
|
||||
raw_data: "\001\000\000\000\000\000\000\000"
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "onnx::Cast_3"
|
||||
output: "onnx::Loop_4"
|
||||
op_type: "Cast"
|
||||
attribute {
|
||||
name: "to"
|
||||
i: 9
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
node {
|
||||
output: "5"
|
||||
name: "Constant_0"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
@ -40,10 +19,12 @@ graph {
|
||||
node {
|
||||
input: "input"
|
||||
output: "onnx::Gather_6"
|
||||
name: "Shape_1"
|
||||
op_type: "Shape"
|
||||
}
|
||||
node {
|
||||
output: "onnx::Gather_7"
|
||||
name: "Constant_2"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
@ -58,6 +39,7 @@ graph {
|
||||
input: "onnx::Gather_6"
|
||||
input: "onnx::Gather_7"
|
||||
output: "onnx::Unsqueeze_8"
|
||||
name: "Gather_3"
|
||||
op_type: "Gather"
|
||||
attribute {
|
||||
name: "axis"
|
||||
@ -67,6 +49,7 @@ graph {
|
||||
}
|
||||
node {
|
||||
output: "onnx::Unsqueeze_9"
|
||||
name: "Constant_4"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
@ -82,12 +65,14 @@ graph {
|
||||
input: "onnx::Unsqueeze_8"
|
||||
input: "onnx::Unsqueeze_9"
|
||||
output: "onnx::Concat_10"
|
||||
name: "Unsqueeze_5"
|
||||
op_type: "Unsqueeze"
|
||||
}
|
||||
node {
|
||||
input: "offsets"
|
||||
input: "onnx::Concat_10"
|
||||
output: "onnx::Slice_11"
|
||||
name: "Concat_6"
|
||||
op_type: "Concat"
|
||||
attribute {
|
||||
name: "axis"
|
||||
@ -97,6 +82,7 @@ graph {
|
||||
}
|
||||
node {
|
||||
output: "onnx::Slice_12"
|
||||
name: "Constant_7"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
@ -110,6 +96,7 @@ graph {
|
||||
}
|
||||
node {
|
||||
output: "onnx::Slice_13"
|
||||
name: "Constant_8"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
@ -123,6 +110,7 @@ graph {
|
||||
}
|
||||
node {
|
||||
output: "onnx::Slice_14"
|
||||
name: "Constant_9"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
@ -136,6 +124,7 @@ graph {
|
||||
}
|
||||
node {
|
||||
output: "onnx::Slice_15"
|
||||
name: "Constant_10"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
@ -154,15 +143,18 @@ graph {
|
||||
input: "onnx::Slice_12"
|
||||
input: "onnx::Slice_15"
|
||||
output: "onnx::Shape_16"
|
||||
name: "Slice_11"
|
||||
op_type: "Slice"
|
||||
}
|
||||
node {
|
||||
input: "onnx::Shape_16"
|
||||
output: "onnx::Gather_17"
|
||||
name: "Shape_12"
|
||||
op_type: "Shape"
|
||||
}
|
||||
node {
|
||||
output: "onnx::Gather_18"
|
||||
name: "Constant_13"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
@ -177,6 +169,7 @@ graph {
|
||||
input: "onnx::Gather_17"
|
||||
input: "onnx::Gather_18"
|
||||
output: "onnx::Loop_19"
|
||||
name: "Gather_14"
|
||||
op_type: "Gather"
|
||||
attribute {
|
||||
name: "axis"
|
||||
@ -186,8 +179,9 @@ graph {
|
||||
}
|
||||
node {
|
||||
input: "onnx::Loop_19"
|
||||
input: "onnx::Loop_4"
|
||||
input: "onnx::Loop_33"
|
||||
output: "20"
|
||||
name: "Loop_15"
|
||||
op_type: "Loop"
|
||||
attribute {
|
||||
name: "body"
|
||||
@ -196,7 +190,7 @@ graph {
|
||||
input: "onnx::Slice_11"
|
||||
input: "21"
|
||||
output: "23"
|
||||
name: "Gather_0"
|
||||
name: "Gather_16"
|
||||
op_type: "Gather"
|
||||
attribute {
|
||||
name: "axis"
|
||||
@ -208,7 +202,7 @@ graph {
|
||||
input: "onnx::Shape_16"
|
||||
input: "21"
|
||||
output: "24"
|
||||
name: "Gather_1"
|
||||
name: "Gather_17"
|
||||
op_type: "Gather"
|
||||
attribute {
|
||||
name: "axis"
|
||||
@ -218,7 +212,7 @@ graph {
|
||||
}
|
||||
node {
|
||||
output: "25"
|
||||
name: "Constant_2"
|
||||
name: "Constant_18"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
@ -234,12 +228,12 @@ graph {
|
||||
input: "23"
|
||||
input: "25"
|
||||
output: "26"
|
||||
name: "Unsqueeze_3"
|
||||
name: "Unsqueeze_19"
|
||||
op_type: "Unsqueeze"
|
||||
}
|
||||
node {
|
||||
output: "27"
|
||||
name: "Constant_4"
|
||||
name: "Constant_20"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
@ -255,7 +249,7 @@ graph {
|
||||
input: "24"
|
||||
input: "27"
|
||||
output: "28"
|
||||
name: "Unsqueeze_5"
|
||||
name: "Unsqueeze_21"
|
||||
op_type: "Unsqueeze"
|
||||
}
|
||||
node {
|
||||
@ -264,14 +258,14 @@ graph {
|
||||
input: "28"
|
||||
input: "5"
|
||||
output: "29"
|
||||
name: "Slice_6"
|
||||
name: "Slice_22"
|
||||
op_type: "Slice"
|
||||
}
|
||||
node {
|
||||
input: "weight"
|
||||
input: "29"
|
||||
output: "30"
|
||||
name: "Gather_7"
|
||||
name: "Gather_23"
|
||||
op_type: "Gather"
|
||||
attribute {
|
||||
name: "axis"
|
||||
@ -282,7 +276,7 @@ graph {
|
||||
node {
|
||||
input: "30"
|
||||
output: "31"
|
||||
name: "ReduceMean_8"
|
||||
name: "ReduceMean_24"
|
||||
op_type: "ReduceMean"
|
||||
attribute {
|
||||
name: "axes"
|
||||
@ -296,9 +290,9 @@ graph {
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "onnx::Loop_4"
|
||||
input: "onnx::Loop_33"
|
||||
output: "32"
|
||||
name: "Cast_9"
|
||||
name: "Cast_25"
|
||||
op_type: "Cast"
|
||||
attribute {
|
||||
name: "to"
|
||||
@ -362,6 +356,11 @@ graph {
|
||||
name: "weight"
|
||||
raw_data: "\264\314\344\275\017A\376\276\313\374&>J\266a\277s\306\\=\212\032+?\211[t\275\344[\357\276Dk\\\276OKb?\234\'B\277A\334\274\2767N\257\276\320s\263\277\371+\244>:\314\202\277K\200L??\001\275\275\236u4\2774\032\315\277\214\004\224>Z\320\372>\267B\305\276\346G6\277N\265.\276\343\316\272\277t\364a>\201)|>p\223\251\277Qm2?\346\275)\277\354\235\233?\027X\277\277\253\206a?\354\335\226\277L\032o\277\251J\021\277\311\360\215\276\312\274\013\300\252\320\273>\220\"p?\267\020\000<R\262\240\276\343\016\224\2779\241\353?8;\202\277\023\020\234?E\370#>\222\233\314?\334\360?\275|t\303\277\214\351\000\300\3065\302\2775\206\306>X\251\227\277x\2160?U^\251?d\221\350?\237F.?\rp9?9X\004=/c\324\277SL\360\277\'\274<?t\375l?\342\270l?\240\352:>\332\356\226\275\211\035\241>*\271\204\277>\025W>\036K\035?\036\233\200=\035\313\250\276\017\003\346\277\374p_?\313WD?!\006\351\275\232\\q\277\230\007A?"
|
||||
}
|
||||
initializer {
|
||||
data_type: 9
|
||||
name: "onnx::Loop_33"
|
||||
raw_data: "\001"
|
||||
}
|
||||
input {
|
||||
name: "input"
|
||||
type {
|
||||
@ -404,6 +403,16 @@ graph {
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "onnx::Loop_33"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 9
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "20"
|
||||
type {
|
||||
|
@ -7,6 +7,7 @@ graph {
|
||||
input: "weight"
|
||||
input: "bias"
|
||||
output: "3"
|
||||
name: "ATen_0"
|
||||
op_type: "ATen"
|
||||
attribute {
|
||||
name: "cudnn_enable"
|
||||
@ -34,6 +35,7 @@ graph {
|
||||
s: ""
|
||||
type: STRING
|
||||
}
|
||||
domain: "org.pytorch.aten"
|
||||
}
|
||||
name: "torch_jit"
|
||||
initializer {
|
||||
@ -130,3 +132,7 @@ graph {
|
||||
opset_import {
|
||||
version: 13
|
||||
}
|
||||
opset_import {
|
||||
domain: "org.pytorch.aten"
|
||||
version: 1
|
||||
}
|
||||
|
@ -20,12 +20,15 @@ import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import torch.testing._internal.common_utils as common
|
||||
from torch.testing._internal.common_utils import skipIfCaffe2
|
||||
|
||||
'''Usage: python test/onnx/test_operators.py [--no-onnx] [--produce-onnx-test-data]
|
||||
--no-onnx: no onnx python dependence
|
||||
--produce-onnx-test-data: generate onnx test data
|
||||
--accept: accept onnx updates and overwrite models
|
||||
'''
|
||||
import unittest
|
||||
unittest.TestCase.maxDiff = None
|
||||
|
||||
_onnx_test = False # flag to produce onnx test cases.
|
||||
_onnx_dep = True # flag to import onnx package.
|
||||
@ -322,6 +325,7 @@ class TestOperators(TestCase):
|
||||
x = torch.randn(20, 16, 50)
|
||||
self.assertONNX(nn.MaxPool1d(3, stride=2, return_indices=True), x)
|
||||
|
||||
@skipIfCaffe2
|
||||
def test_at_op(self):
|
||||
x = torch.randn(3, 4)
|
||||
|
||||
@ -339,7 +343,8 @@ class TestOperators(TestCase):
|
||||
def forward(self, x):
|
||||
return MyFun.apply(x)
|
||||
|
||||
self.assertONNX(MyModule(), x)
|
||||
self.assertONNX(MyModule(), x,
|
||||
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
|
||||
|
||||
def test_clip(self):
|
||||
x = torch.randn(3, 4, requires_grad=True)
|
||||
@ -588,6 +593,7 @@ class TestOperators(TestCase):
|
||||
self.assertONNX(nn.BatchNorm2d(128, affine=False, momentum=0.3), x,
|
||||
keep_initializers_as_inputs=True)
|
||||
|
||||
@skipIfCaffe2
|
||||
def test_embedding_bags(self):
|
||||
emb_bag = nn.EmbeddingBag(10, 8)
|
||||
input = torch.tensor([1, 2, 3, 4]).long()
|
||||
@ -787,6 +793,7 @@ class TestOperators(TestCase):
|
||||
input2 = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2)
|
||||
self.assertONNX(BitshiftModel(), (input, input2), opset_version=11)
|
||||
|
||||
@skipIfCaffe2
|
||||
def test_layer_norm_aten(self):
|
||||
model = torch.nn.LayerNorm([10, 10])
|
||||
x = torch.randn(20, 5, 10, 10)
|
||||
@ -954,7 +961,7 @@ class TestOperators(TestCase):
|
||||
f'"sparse":{str(sparse).lower()}'
|
||||
'}'
|
||||
)
|
||||
output = g.op("com.microsoft::ATenOp", weight, indices, name_s='aten::embedding',
|
||||
output = g.at("embedding", weight, indices,
|
||||
custom_attributes_json_s=custom_attributes_json)
|
||||
return output
|
||||
|
||||
@ -978,6 +985,7 @@ class TestOperators(TestCase):
|
||||
unregister_custom_op_symbolic('::embedding', _onnx_opset_version)
|
||||
|
||||
# This is test_aten_embedding_1 with shape inference on custom symbolic aten::embedding.
|
||||
@skipIfCaffe2
|
||||
def test_aten_embedding_2(self):
|
||||
_onnx_opset_version = 12
|
||||
|
||||
@ -990,7 +998,7 @@ class TestOperators(TestCase):
|
||||
f'"sparse":{str(sparse).lower()}'
|
||||
'}'
|
||||
)
|
||||
output = g.op("com.microsoft::ATenOp", weight, indices, name_s='aten::embedding',
|
||||
output = g.at("embedding", weight, indices,
|
||||
custom_attributes_json_s=custom_attributes_json)
|
||||
|
||||
# do shape inference and set it via setType
|
||||
@ -1016,7 +1024,9 @@ class TestOperators(TestCase):
|
||||
x = torch.ones(32, dtype=torch.long)
|
||||
y = torch.randn(1, 8)
|
||||
self.assertONNX(model, (x, y), opset_version=_onnx_opset_version, input_names=['input_1', 'input_2'],
|
||||
dynamic_axes={"input_1": {0: "dim_0"}, 'input_2': {0: "dim_1", 1: "dim_2"}})
|
||||
dynamic_axes={"input_1": {0: "dim_0"}, 'input_2': {0: "dim_1", 1: "dim_2"}},
|
||||
keep_initializers_as_inputs=False,
|
||||
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
|
||||
|
||||
unregister_custom_op_symbolic('::embedding', _onnx_opset_version)
|
||||
|
||||
|
@@ -62,6 +62,8 @@ from torch.testing._internal.common_quantized import (
     override_qengines,
 )
 from torch.testing._internal.jit_utils import JitTestCase
+from torch.testing._internal.common_utils import skipIfNoCaffe2
+
 from hypothesis import given
 from hypothesis import strategies as st
 import torch.testing._internal.hypothesis_utils as hu
@@ -1464,6 +1466,7 @@ class TestQuantizeEagerONNXExport(JitTestCase):
         onnx_model = export_to_onnx(model, data, input_names)

     @skipIfNoFBGEMM
+    @skipIfNoCaffe2
     def test_lower_graph_linear(self):
         model = torch.ao.quantization.QuantWrapper(torch.nn.Linear(5, 10, bias=True)).to(dtype=torch.float)
         data_numpy = np.random.rand(1, 2, 5).astype(np.float32)
@@ -1471,6 +1474,7 @@ class TestQuantizeEagerONNXExport(JitTestCase):
         self._test_lower_graph_impl(model, data)

     @skipIfNoFBGEMM
+    @skipIfNoCaffe2
     def test_lower_graph_conv2d(self):
         model = torch.ao.quantization.QuantWrapper(torch.nn.Conv2d(3, 5, 2, bias=True)).to(dtype=torch.float)
         data_numpy = np.random.rand(1, 3, 6, 6).astype(np.float32)
@@ -362,7 +362,6 @@ if(USE_NUMPY)
   target_compile_definitions(torch_python PRIVATE USE_NUMPY)
 endif()

-list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS BUILD_CAFFE2)
 if(HAVE_SOVERSION)
   set_target_properties(torch_python PROPERTIES
       VERSION ${TORCH_VERSION} SOVERSION ${TORCH_SOVERSION})
@@ -1905,7 +1905,8 @@ static std::unordered_set<std::string> nodeTypeReliableForTracer = {
     "onnx::Cast",
     "onnx::Constant",
     "onnx::Relu",
-    "com.microsoft::Gelu"};
+    "com.microsoft::Gelu",
+    "aten::ATen"};

 void UpdateReliable(
     torch::jit::Value* output,
@@ -113,7 +113,7 @@ void validateBlock(
         WithInsertPoint guard(node);
         auto* new_node =
             b->owningGraph()->insertNode(b->owningGraph()->create(
-                Symbol(::c10::onnx::ATen),
+                Symbol(::c10::aten::ATen),
                 node->inputs(),
                 node->outputs().size()));
         for (size_t i = 0; i < node->outputs().size(); ++i) {
@@ -1163,8 +1163,8 @@ void GraphEncoder::EncodeIntermediateValueInfo(
     const Value* v) {
   // Motivation is to encode ValueInfo for onnx local function nodes.
   auto n = v->node();
-  if (n->kind().is_onnx()) {
-    // Encode value info only for non-onnx nodes.
+  if (n->kind().is_onnx() || n->kind().is_aten()) {
+    // Encode value info only for non-onnx or non-ATen nodes.
     return;
   }
   if (n->owningGraph() != graph_.get()) {
@@ -329,6 +329,10 @@ def _is_scalar_list(x):
         element_type in scalar_name_to_pytorch.keys() and \
         (scalar_name_to_pytorch[element_type] in cast_pytorch_to_onnx.keys())

+def is_caffe2_aten_fallback():
+    return (_operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK and
+            torch.onnx._CAFFE2_ATEN_FALLBACK)
+
 def _get_tensor_rank(x):
     if not _is_tensor(x) or x.type() is None:
         return None
@@ -551,10 +551,7 @@ def arange(g, *args):
 def _dim_arange(g, like, dim):
     like_shape = g.op("Shape", like)
     stop = g.op("Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0)
-    # Caffe2-specific op
-    is_caffe2_aten_fallback = (sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK and
-                               torch.onnx._CAFFE2_ATEN_FALLBACK)
-    if is_caffe2_aten_fallback:
+    if sym_help.is_caffe2_aten_fallback():
         return g.op("_caffe2::Range", stop)
     return arange(g, stop, 4, None, None, None)
|
||||
@ -643,7 +640,8 @@ def index(g, self, index):
|
||||
def index_fill(g, self, dim, index, value):
|
||||
dim_value = sym_help._parse_arg(dim, "i")
|
||||
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
|
||||
return g.at("index_fill", self, index, value, dim_i=dim_value, overload_name="int_Scalar")
|
||||
return g.at("index_fill", self, index, value, overload_name="int_Scalar", dim_i=dim_value)
|
||||
|
||||
expanded_index_shape, expanded_index = sym_help._index_fill_reshape_helper(g, self, dim, index)
|
||||
value = sym_help._maybe_get_scalar(value)
|
||||
value = sym_help._if_scalar_type_as(g, value, self)
|
||||
|
@ -565,7 +565,7 @@ def transpose(g, self, dim0, dim1):
|
||||
# if we don't have dim information we cannot
|
||||
# output a permute so use ATen instead
|
||||
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
|
||||
return g.at("transpose", self, dim0_i=dim0, dim1_i=dim1, overload_name="int")
|
||||
return g.at("transpose", self, overload_name="int", dim0_i=dim0, dim1_i=dim1)
|
||||
else:
|
||||
raise RuntimeError("Unsupported: ONNX export of transpose for tensor "
|
||||
"of unknown rank.")
|
||||
@ -1581,7 +1581,8 @@ def index_put(g, self, indices_list_value, values, accumulate):
|
||||
def index_fill(g, self, dim, index, value):
|
||||
dim_value = sym_help._parse_arg(dim, "i")
|
||||
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
|
||||
return g.at("index_fill", self, index, value, dim_i=dim_value, overload_name="int_Scalar")
|
||||
return g.at("index_fill", self, index, value, overload_name="int_Scalar", dim_i=dim_value)
|
||||
|
||||
expanded_index_shape, expanded_index = sym_help._index_fill_reshape_helper(g, self, dim, index)
|
||||
value = sym_help._maybe_get_scalar(value)
|
||||
value = sym_help._if_scalar_type_as(g, value, self)
|
||||
@ -1647,9 +1648,7 @@ def type_as(g, self, other):
|
||||
|
||||
@parse_args("v", "v", "i", "f")
|
||||
def cosine_similarity(g, x1, x2, dim, eps):
|
||||
# preserve legacy behavior for Caffe2
|
||||
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK and \
|
||||
torch.onnx._CAFFE2_ATEN_FALLBACK:
|
||||
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
|
||||
return g.at("cosine_similarity", x1, x2, dim_i=dim, eps_f=eps)
|
||||
cross = sym_help._reducesum_helper(g, mul(g, x1, x2),
|
||||
axes_i=[dim], keepdims_i=0)
|
||||
@ -2599,10 +2598,7 @@ rnn_relu = _one_hidden_rnn("RNN_RELU")
|
||||
def _dim_arange(g, like, dim):
|
||||
like_shape = g.op("Shape", like)
|
||||
stop = g.op("Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0)
|
||||
# Caffe2-specific op
|
||||
is_caffe2_aten_fallback = (sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK and
|
||||
torch.onnx._CAFFE2_ATEN_FALLBACK)
|
||||
if is_caffe2_aten_fallback:
|
||||
if sym_help.is_caffe2_aten_fallback():
|
||||
return g.op("_caffe2::Range", stop)
|
||||
else:
|
||||
# aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
|
||||
|
@ -190,7 +190,7 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa
|
||||
torch._C._jit_pass_peephole(graph, True)
|
||||
torch._C._jit_pass_fuse_addmm(graph)
|
||||
torch._C._jit_pass_lint(graph)
|
||||
from torch.onnx.symbolic_helper import _onnx_shape_inference, _export_onnx_opset_version
|
||||
from torch.onnx.symbolic_helper import _onnx_shape_inference, _export_onnx_opset_version, is_caffe2_aten_fallback
|
||||
|
||||
torch._C._jit_pass_peephole(graph, True)
|
||||
torch._C._jit_pass_lower_all_tuples(graph)
|
||||
@ -212,13 +212,10 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa
|
||||
torch._C._jit_pass_onnx_remove_print(graph)
|
||||
torch._C._jit_pass_onnx_preprocess_caffe2(graph)
|
||||
|
||||
# Caffe2-specific optimization
|
||||
is_caffe2_aten_fallback = (operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK and
|
||||
torch.onnx._CAFFE2_ATEN_FALLBACK)
|
||||
torch.onnx.symbolic_helper._quantized_ops.clear()
|
||||
# Unpack quantized weights for conv and linear ops and insert into graph.
|
||||
torch._C._jit_pass_onnx_unpack_quantized_weights(graph, params_dict, is_caffe2_aten_fallback)
|
||||
if is_caffe2_aten_fallback:
|
||||
torch._C._jit_pass_onnx_unpack_quantized_weights(graph, params_dict, is_caffe2_aten_fallback())
|
||||
if is_caffe2_aten_fallback():
|
||||
# Insert permutes before and after each conv op to ensure correct order.
|
||||
torch._C._jit_pass_onnx_quantization_insert_permutes(graph, params_dict)
|
||||
|
||||
@ -289,7 +286,7 @@ def warn_on_static_input_change(input_states):
|
||||
|
||||
def _resolve_args_by_export_type(arg_name, arg_value, operator_export_type):
|
||||
# This helper method resolves the arguments that are ignored when export_type != operator_export_type.ONNX
|
||||
if operator_export_type is not operator_export_type.ONNX:
|
||||
if operator_export_type is not operator_export_type.ONNX and torch.onnx._CAFFE2_ATEN_FALLBACK:
|
||||
if arg_value is True:
|
||||
warnings.warn("`{}' can be set to True only when 'operator_export_type' is "
|
||||
"`ONNX`. Since 'operator_export_type' is not set to 'ONNX', "
|
||||
@ -959,7 +956,8 @@ def _add_attribute(node, key, value, aten):
|
||||
name, kind = m.group(1), m.group(2)
|
||||
if _is_onnx_list(value):
|
||||
kind += "s"
|
||||
if aten:
|
||||
from torch.onnx.symbolic_helper import is_caffe2_aten_fallback
|
||||
if aten and is_caffe2_aten_fallback():
|
||||
if isinstance(value, torch.Tensor):
|
||||
# Caffe2 proto does not support tensor attribute.
|
||||
if value.numel() > 1:
|
||||
@ -1119,14 +1117,13 @@ def _run_symbolic_function(g, block, n, inputs, env, operator_export_type=Operat
|
||||
try:
|
||||
import torch
|
||||
from torch.onnx.symbolic_helper import _export_onnx_opset_version as opset_version
|
||||
from torch.onnx.symbolic_helper import is_caffe2_aten_fallback
|
||||
import torch.onnx.symbolic_registry as sym_registry
|
||||
|
||||
sym_registry.register_version("", opset_version)
|
||||
|
||||
# Caffe2-specific: Quantized op symbolics are registered for opset 9 only.
|
||||
is_caffe2_aten_fallback = (operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK and
|
||||
torch.onnx._CAFFE2_ATEN_FALLBACK)
|
||||
if is_caffe2_aten_fallback and opset_version == 9:
|
||||
if is_caffe2_aten_fallback() and opset_version == 9:
|
||||
import torch.onnx.symbolic_caffe2
|
||||
torch.onnx.symbolic_caffe2.register_quantized_ops("caffe2", opset_version)
|
||||
|
||||
@ -1137,11 +1134,10 @@ def _run_symbolic_function(g, block, n, inputs, env, operator_export_type=Operat
|
||||
else:
|
||||
ns_op_name = n.kind()
|
||||
ns, op_name = ns_op_name.split("::")
|
||||
|
||||
domain = ns
|
||||
if ns == "aten":
|
||||
domain = ""
|
||||
elif ns == "quantized" and is_caffe2_aten_fallback:
|
||||
elif ns == "quantized" and is_caffe2_aten_fallback():
|
||||
domain = "caffe2"
|
||||
|
||||
if sym_registry.is_registered_op(op_name, domain, opset_version):
|
||||
@ -1165,7 +1161,7 @@ def _run_symbolic_function(g, block, n, inputs, env, operator_export_type=Operat
|
||||
attrs = {k + "_" + n.kindOf(k)[0]: n[k] for k in n.attributeNames()}
|
||||
outputs = n.outputsSize()
|
||||
attrs["outputs"] = outputs
|
||||
return g.at(op_name, *inputs, aten=True, **attrs)
|
||||
return g.at(op_name, *inputs, **attrs)
|
||||
else:
|
||||
raise sym_registry.UnsupportedOperatorError(domain, op_name, opset_version)
|
||||
except RuntimeError:
|
||||
@ -1181,6 +1177,7 @@ def _run_symbolic_function(g, block, n, inputs, env, operator_export_type=Operat
|
||||
|
||||
# Generate an ONNX ATen op node.
|
||||
def _aten_op(g, operator, *args, overload_name="", **kwargs):
|
||||
kwargs["aten"] = True
|
||||
return g.op("ATen", *args, operator_s=operator, overload_name_s=overload_name, **kwargs)
|
||||
|
||||
|
||||
@ -1315,7 +1312,7 @@ def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names):
|
||||
|
||||
|
||||
torch._C.Graph.op = _graph_op # type: ignore[attr-defined]
|
||||
torch._C.Graph.at = _aten_op # type: ignore[attr-defined]
|
||||
torch._C.Graph.at = _aten_op # type: ignore[attr-defined]
|
||||
torch._C.Block.op = _block_op # type: ignore[attr-defined]
|
||||
torch._C.Graph.constant = _graph_constant # type: ignore[attr-defined]
|
||||
torch._C.Node.__getitem__ = _node_getitem # type: ignore[attr-defined, misc, assignment]
|
||||
|
@ -295,6 +295,14 @@ def skipIfNoQNNPACK(fn):
|
||||
fn(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
if not torch.onnx._CAFFE2_ATEN_FALLBACK:
|
||||
raise unittest.SkipTest(reason)
|
||||
else:
|
||||
fn(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
try:
|
||||
import torchvision # noqa: F401
|
||||
HAS_TORCHVISION = True
|
||||
|
@@ -1023,7 +1023,6 @@ def skipIfNoLapack(fn):
             fn(*args, **kwargs)
     return wrapper

-
 def skipIfNotRegistered(op_name, message):
     """Wraps the decorator to hide the import of the `core`.
@@ -1045,6 +1044,17 @@ def skipIfNotRegistered(op_name, message):
         skipper = unittest.skip("Cannot import `caffe2.python.core`")
     return skipper

+def _decide_skip_caffe2(expect_caffe2, reason):
+    def skip_dec(func):
+        def wrapper(self):
+            if torch.onnx._CAFFE2_ATEN_FALLBACK != expect_caffe2:
+                raise unittest.SkipTest(reason)
+            return func(self)
+        return wrapper
+    return skip_dec
+
+skipIfCaffe2 = _decide_skip_caffe2(False, "Not compatible with Caffe2")
+skipIfNoCaffe2 = _decide_skip_caffe2(True, "Caffe2 is not available")

 def skipIfNoSciPy(fn):
     @wraps(fn)