Fix ONNX ATen fallback for non-caffe2 engines

This PR introduces 3 BC changes:

First, this PR propagates `BUILD_CAFFE2` flag to `libtorch` and `libtorch_python`, which is necessary for non-caffe2 ONNX runtimes when using `ONNX_ATEN_FALLBACK` operator export type.

Second, as a complement of https://github.com/pytorch/pytorch/pull/68490, this PR refactors Caffe2's Aten ops symbolics to consider not only the `operator_export_type` (aka `ONNX_ATEN_FALLBACK`) to emit Caffe2 Aten ops, but also whether `BUILD_CAFFE2` (which is called `torch.onnx._CAFFE2_ATEN_FALLBACK` in python binding) is set.

Lastly, it renames `onnx::ATen` to `aten::ATen` for ONNX spec consistency in a BC fashion.
ONNX doesn't have `ATen` op on its spec, but PyTorch ONNX converter emits them. Non-Caffe2 backend engines would be mislead by such operator's name/domain. A non-ideal workaround would be to have Aten ops handled based on its name and ignore the (non-complaint) domain. Moreover, users could incorrectly file bugs to either ONNX or ONNX Runtime when they inspect the model and notice the presence of an unspecified ONNX operator.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73954
Approved by: https://github.com/BowenBao, https://github.com/malfet, https://github.com/garymm, https://github.com/jiafatom
This commit is contained in:
Thiago Crepaldi
2022-04-14 23:18:45 +00:00
committed by PyTorch MergeBot
parent b142a224c6
commit 9bbe1d632e
26 changed files with 146 additions and 112 deletions

View File

@ -247,7 +247,7 @@ namespace c10 {
_(onnx, Less) \ _(onnx, Less) \
_(onnx, LessOrEqual) \ _(onnx, LessOrEqual) \
_(onnx, Not) \ _(onnx, Not) \
_(onnx, ATen) \ _(aten, ATen) \
_(onnx, Split) \ _(onnx, Split) \
_(onnx, ConstantOfShape) \ _(onnx, ConstantOfShape) \
_(onnx, Cast) \ _(onnx, Cast) \
@ -316,7 +316,8 @@ namespace c10 {
_(attr, new_axis) \ _(attr, new_axis) \
_(attr, warn_id) \ _(attr, warn_id) \
_(attr, allowzero) \ _(attr, allowzero) \
_(attr, seen_none) _(attr, seen_none) \
_(attr, overload_name)
enum class _keys : unique_t { enum class _keys : unique_t {
#define DEFINE_KEY(ns, s) ns##_##s, #define DEFINE_KEY(ns, s) ns##_##s,

View File

@ -59,7 +59,7 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Utilitity to generate Caffe2 benchmark models.") description="Utility to generate Caffe2 benchmark models.")
parser.add_argument("operator", help="Caffe2 operator to benchmark.") parser.add_argument("operator", help="Caffe2 operator to benchmark.")
parser.add_argument("-b", "--blob", parser.add_argument("-b", "--blob",
help="Instantiate a blob --blob name=dim1,dim2,dim3", help="Instantiate a blob --blob name=dim1,dim2,dim3",

View File

@ -1943,6 +1943,8 @@ if(BUILD_PYTHON)
# ---[ Python. # ---[ Python.
if(BUILD_CAFFE2) if(BUILD_CAFFE2)
add_library(caffe2_pybind11_state MODULE ${Caffe2_CPU_PYTHON_SRCS}) add_library(caffe2_pybind11_state MODULE ${Caffe2_CPU_PYTHON_SRCS})
target_compile_definitions(torch PRIVATE BUILD_CAFFE2)
target_compile_definitions(torch_python PRIVATE BUILD_CAFFE2)
if(USE_NUMPY) if(USE_NUMPY)
target_compile_options(caffe2_pybind11_state PRIVATE "-DUSE_NUMPY") target_compile_options(caffe2_pybind11_state PRIVATE "-DUSE_NUMPY")
target_link_libraries(caffe2_pybind11_state PRIVATE numpy::numpy) target_link_libraries(caffe2_pybind11_state PRIVATE numpy::numpy)

View File

@ -72,7 +72,7 @@ class Add(torch.autograd.Function):
@staticmethod @staticmethod
def symbolic(g, a, b): def symbolic(g, a, b):
return g.op("ATen", a, b, operator_s = "add") return g.at("add", a, b)
@staticmethod @staticmethod
def forward(ctx, a, b): def forward(ctx, a, b):

View File

@ -179,7 +179,7 @@ private:
std::vector<std::string> attrs; std::vector<std::string> attrs;
for (const auto i : c10::irange(operator_def.arg_size())) { for (const auto i : c10::irange(operator_def.arg_size())) {
auto & attr = operator_def.arg(i); auto & attr = operator_def.arg(i);
if(attr.name() == "operator" || attr.name() == "type" || attr.name() == "overload_name" ) { if (attr.name() == "operator" || attr.name() == "type" || attr.name() == "overload_name") {
continue; continue;
} }
attrs.push_back(attr.name()); attrs.push_back(attr.name());

View File

@ -1,9 +1,4 @@
from caffe2.python import core
from caffe2.python import core, dyndep
from hypothesis import given from hypothesis import given
import caffe2.python.hypothesis_test_util as hu import caffe2.python.hypothesis_test_util as hu

View File

@ -38,8 +38,8 @@ torch.onnx.export(MyModule(),
# graph(%input : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu), # graph(%input : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
# %y : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)): # %y : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
# %2 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = onnx::Relu(%input) # %2 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = onnx::Relu(%input)
# %3 : Tensor = onnx::ATen[operator="mul"](%2, %2) # %3 : Tensor = aten::ATen[operator="mul"](%2, %2)
# %4 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = onnx::ATen[operator="add"](%3, %y) # %4 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::ATen[operator="add"](%3, %y)
# return (%4) # return (%4)
graph = onnx.load(f.name) graph = onnx.load(f.name)

View File

@ -106,7 +106,7 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Utilitity to generate Caffe2 benchmark models.") description="Utility to generate Caffe2 benchmark models.")
parser.add_argument("operator", help="Caffe2 operator to benchmark.") parser.add_argument("operator", help="Caffe2 operator to benchmark.")
parser.add_argument("-b", "--blob", parser.add_argument("-b", "--blob",
help="Instantiate a blob --blob name=dim1,dim2,dim3", help="Instantiate a blob --blob name=dim1,dim2,dim3",

View File

@ -11,7 +11,7 @@ ModelProto {
nodes: [ nodes: [
Node {type: "Add", inputs: [0,1], outputs: [2], attributes: []}, Node {type: "Add", inputs: [0,1], outputs: [2], attributes: []},
Node {type: "Constant", inputs: [], outputs: [3], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]}, Node {type: "Constant", inputs: [], outputs: [3], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "ATen", inputs: [2,3], outputs: [4,5], attributes: [{ name: 'operator', type: string, value: 'qr'}, { name: 'overload_name', type: string, value: ''}]} Node {type: "ATen", domain: "org.pytorch.aten", inputs: [2,3], outputs: [4,5], attributes: [{ name: 'operator', type: string, value: 'qr'}, { name: 'overload_name', type: string, value: ''}]}
] ]
} }
opset_import: [OperatorSetIdProto { domain: }OperatorSetIdProto { domain: org.pytorch.aten}], opset_import: [OperatorSetIdProto { domain: }OperatorSetIdProto { domain: org.pytorch.aten}],

View File

@ -9,7 +9,7 @@ ModelProto {
outputs: [{name: "2", type:Tensor dims: 3 4}] outputs: [{name: "2", type:Tensor dims: 3 4}]
initializers: [] initializers: []
nodes: [ nodes: [
Node {type: "ATen", inputs: [0,1], outputs: [2], attributes: [{ name: 'operator', type: string, value: 'fmod'}, { name: 'overload_name', type: string, value: ''}]} Node {type: "ATen", domain: "org.pytorch.aten", inputs: [0,1], outputs: [2], attributes: [{ name: 'operator', type: string, value: 'fmod'}, { name: 'overload_name', type: string, value: ''}]}
] ]
} }
opset_import: [OperatorSetIdProto { domain: }OperatorSetIdProto { domain: org.pytorch.aten}], opset_import: [OperatorSetIdProto { domain: }OperatorSetIdProto { domain: org.pytorch.aten}],

View File

@ -13,7 +13,7 @@ ModelProto {
Node {type: "Less", inputs: [0,1], outputs: [2], attributes: []}, Node {type: "Less", inputs: [0,1], outputs: [2], attributes: []},
Node {type: "Cast", inputs: [2], outputs: [3], attributes: [{ name: 'to', type: int, value: 2}]}, Node {type: "Cast", inputs: [2], outputs: [3], attributes: [{ name: 'to', type: int, value: 2}]},
Node {type: "Cast", inputs: [3], outputs: [4], attributes: [{ name: 'to', type: int, value: 9}]}, Node {type: "Cast", inputs: [3], outputs: [4], attributes: [{ name: 'to', type: int, value: 9}]},
Node {type: "ATen", inputs: [0,4], outputs: [5], attributes: [{ name: 'operator', type: string, value: 'index'}, { name: 'overload_name', type: string, value: ''}]} Node {type: "ATen", domain: "org.pytorch.aten", inputs: [0,4], outputs: [5], attributes: [{ name: 'operator', type: string, value: 'index'}, { name: 'overload_name', type: string, value: ''}]}
] ]
} }
opset_import: [OperatorSetIdProto { domain: }OperatorSetIdProto { domain: org.pytorch.aten}], opset_import: [OperatorSetIdProto { domain: }OperatorSetIdProto { domain: org.pytorch.aten}],

View File

@ -18,6 +18,7 @@ graph {
s: "" s: ""
type: STRING type: STRING
} }
domain: "org.pytorch.aten"
} }
name: "torch_jit" name: "torch_jit"
input { input {
@ -56,3 +57,7 @@ graph {
opset_import { opset_import {
version: 13 version: 13
} }
opset_import {
domain: "org.pytorch.aten"
version: 1
}

View File

@ -6,19 +6,24 @@ graph {
input: "emb.weight" input: "emb.weight"
input: "input_1" input: "input_1"
output: "onnx::Add_3" output: "onnx::Add_3"
name: "ATenOp_0" name: "ATen_0"
op_type: "ATenOp" op_type: "ATen"
attribute { attribute {
name: "custom_attributes_json" name: "custom_attributes_json"
s: "{\"padding_idx\":-1,\"scale_grad_by_freq\":false,\"sparse\":false}" s: "{\"padding_idx\":-1,\"scale_grad_by_freq\":false,\"sparse\":false}"
type: STRING type: STRING
} }
attribute { attribute {
name: "name" name: "operator"
s: "aten::embedding" s: "embedding"
type: STRING type: STRING
} }
domain: "com.microsoft" attribute {
name: "overload_name"
s: ""
type: STRING
}
domain: "org.pytorch.aten"
} }
node { node {
input: "onnx::Add_3" input: "onnx::Add_3"
@ -145,27 +150,11 @@ graph {
} }
} }
} }
value_info {
name: "onnx::Add_3"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_param: "ATenOponnx::Add_3_dim_0"
}
dim {
dim_param: "ATenOponnx::Add_3_dim_1"
}
}
}
}
}
} }
opset_import { opset_import {
version: 12 version: 12
} }
opset_import { opset_import {
domain: "com.microsoft" domain: "org.pytorch.aten"
version: 1 version: 1
} }

View File

@ -2,30 +2,9 @@ ir_version: 7
producer_name: "pytorch" producer_name: "pytorch"
producer_version: "CURRENT_VERSION" producer_version: "CURRENT_VERSION"
graph { graph {
node {
output: "onnx::Cast_3"
op_type: "Constant"
attribute {
name: "value"
t {
data_type: 7
raw_data: "\001\000\000\000\000\000\000\000"
}
type: TENSOR
}
}
node {
input: "onnx::Cast_3"
output: "onnx::Loop_4"
op_type: "Cast"
attribute {
name: "to"
i: 9
type: INT
}
}
node { node {
output: "5" output: "5"
name: "Constant_0"
op_type: "Constant" op_type: "Constant"
attribute { attribute {
name: "value" name: "value"
@ -40,10 +19,12 @@ graph {
node { node {
input: "input" input: "input"
output: "onnx::Gather_6" output: "onnx::Gather_6"
name: "Shape_1"
op_type: "Shape" op_type: "Shape"
} }
node { node {
output: "onnx::Gather_7" output: "onnx::Gather_7"
name: "Constant_2"
op_type: "Constant" op_type: "Constant"
attribute { attribute {
name: "value" name: "value"
@ -58,6 +39,7 @@ graph {
input: "onnx::Gather_6" input: "onnx::Gather_6"
input: "onnx::Gather_7" input: "onnx::Gather_7"
output: "onnx::Unsqueeze_8" output: "onnx::Unsqueeze_8"
name: "Gather_3"
op_type: "Gather" op_type: "Gather"
attribute { attribute {
name: "axis" name: "axis"
@ -67,6 +49,7 @@ graph {
} }
node { node {
output: "onnx::Unsqueeze_9" output: "onnx::Unsqueeze_9"
name: "Constant_4"
op_type: "Constant" op_type: "Constant"
attribute { attribute {
name: "value" name: "value"
@ -82,12 +65,14 @@ graph {
input: "onnx::Unsqueeze_8" input: "onnx::Unsqueeze_8"
input: "onnx::Unsqueeze_9" input: "onnx::Unsqueeze_9"
output: "onnx::Concat_10" output: "onnx::Concat_10"
name: "Unsqueeze_5"
op_type: "Unsqueeze" op_type: "Unsqueeze"
} }
node { node {
input: "offsets" input: "offsets"
input: "onnx::Concat_10" input: "onnx::Concat_10"
output: "onnx::Slice_11" output: "onnx::Slice_11"
name: "Concat_6"
op_type: "Concat" op_type: "Concat"
attribute { attribute {
name: "axis" name: "axis"
@ -97,6 +82,7 @@ graph {
} }
node { node {
output: "onnx::Slice_12" output: "onnx::Slice_12"
name: "Constant_7"
op_type: "Constant" op_type: "Constant"
attribute { attribute {
name: "value" name: "value"
@ -110,6 +96,7 @@ graph {
} }
node { node {
output: "onnx::Slice_13" output: "onnx::Slice_13"
name: "Constant_8"
op_type: "Constant" op_type: "Constant"
attribute { attribute {
name: "value" name: "value"
@ -123,6 +110,7 @@ graph {
} }
node { node {
output: "onnx::Slice_14" output: "onnx::Slice_14"
name: "Constant_9"
op_type: "Constant" op_type: "Constant"
attribute { attribute {
name: "value" name: "value"
@ -136,6 +124,7 @@ graph {
} }
node { node {
output: "onnx::Slice_15" output: "onnx::Slice_15"
name: "Constant_10"
op_type: "Constant" op_type: "Constant"
attribute { attribute {
name: "value" name: "value"
@ -154,15 +143,18 @@ graph {
input: "onnx::Slice_12" input: "onnx::Slice_12"
input: "onnx::Slice_15" input: "onnx::Slice_15"
output: "onnx::Shape_16" output: "onnx::Shape_16"
name: "Slice_11"
op_type: "Slice" op_type: "Slice"
} }
node { node {
input: "onnx::Shape_16" input: "onnx::Shape_16"
output: "onnx::Gather_17" output: "onnx::Gather_17"
name: "Shape_12"
op_type: "Shape" op_type: "Shape"
} }
node { node {
output: "onnx::Gather_18" output: "onnx::Gather_18"
name: "Constant_13"
op_type: "Constant" op_type: "Constant"
attribute { attribute {
name: "value" name: "value"
@ -177,6 +169,7 @@ graph {
input: "onnx::Gather_17" input: "onnx::Gather_17"
input: "onnx::Gather_18" input: "onnx::Gather_18"
output: "onnx::Loop_19" output: "onnx::Loop_19"
name: "Gather_14"
op_type: "Gather" op_type: "Gather"
attribute { attribute {
name: "axis" name: "axis"
@ -186,8 +179,9 @@ graph {
} }
node { node {
input: "onnx::Loop_19" input: "onnx::Loop_19"
input: "onnx::Loop_4" input: "onnx::Loop_33"
output: "20" output: "20"
name: "Loop_15"
op_type: "Loop" op_type: "Loop"
attribute { attribute {
name: "body" name: "body"
@ -196,7 +190,7 @@ graph {
input: "onnx::Slice_11" input: "onnx::Slice_11"
input: "21" input: "21"
output: "23" output: "23"
name: "Gather_0" name: "Gather_16"
op_type: "Gather" op_type: "Gather"
attribute { attribute {
name: "axis" name: "axis"
@ -208,7 +202,7 @@ graph {
input: "onnx::Shape_16" input: "onnx::Shape_16"
input: "21" input: "21"
output: "24" output: "24"
name: "Gather_1" name: "Gather_17"
op_type: "Gather" op_type: "Gather"
attribute { attribute {
name: "axis" name: "axis"
@ -218,7 +212,7 @@ graph {
} }
node { node {
output: "25" output: "25"
name: "Constant_2" name: "Constant_18"
op_type: "Constant" op_type: "Constant"
attribute { attribute {
name: "value" name: "value"
@ -234,12 +228,12 @@ graph {
input: "23" input: "23"
input: "25" input: "25"
output: "26" output: "26"
name: "Unsqueeze_3" name: "Unsqueeze_19"
op_type: "Unsqueeze" op_type: "Unsqueeze"
} }
node { node {
output: "27" output: "27"
name: "Constant_4" name: "Constant_20"
op_type: "Constant" op_type: "Constant"
attribute { attribute {
name: "value" name: "value"
@ -255,7 +249,7 @@ graph {
input: "24" input: "24"
input: "27" input: "27"
output: "28" output: "28"
name: "Unsqueeze_5" name: "Unsqueeze_21"
op_type: "Unsqueeze" op_type: "Unsqueeze"
} }
node { node {
@ -264,14 +258,14 @@ graph {
input: "28" input: "28"
input: "5" input: "5"
output: "29" output: "29"
name: "Slice_6" name: "Slice_22"
op_type: "Slice" op_type: "Slice"
} }
node { node {
input: "weight" input: "weight"
input: "29" input: "29"
output: "30" output: "30"
name: "Gather_7" name: "Gather_23"
op_type: "Gather" op_type: "Gather"
attribute { attribute {
name: "axis" name: "axis"
@ -282,7 +276,7 @@ graph {
node { node {
input: "30" input: "30"
output: "31" output: "31"
name: "ReduceMean_8" name: "ReduceMean_24"
op_type: "ReduceMean" op_type: "ReduceMean"
attribute { attribute {
name: "axes" name: "axes"
@ -296,9 +290,9 @@ graph {
} }
} }
node { node {
input: "onnx::Loop_4" input: "onnx::Loop_33"
output: "32" output: "32"
name: "Cast_9" name: "Cast_25"
op_type: "Cast" op_type: "Cast"
attribute { attribute {
name: "to" name: "to"
@ -362,6 +356,11 @@ graph {
name: "weight" name: "weight"
raw_data: "\264\314\344\275\017A\376\276\313\374&>J\266a\277s\306\\=\212\032+?\211[t\275\344[\357\276Dk\\\276OKb?\234\'B\277A\334\274\2767N\257\276\320s\263\277\371+\244>:\314\202\277K\200L??\001\275\275\236u4\2774\032\315\277\214\004\224>Z\320\372>\267B\305\276\346G6\277N\265.\276\343\316\272\277t\364a>\201)|>p\223\251\277Qm2?\346\275)\277\354\235\233?\027X\277\277\253\206a?\354\335\226\277L\032o\277\251J\021\277\311\360\215\276\312\274\013\300\252\320\273>\220\"p?\267\020\000<R\262\240\276\343\016\224\2779\241\353?8;\202\277\023\020\234?E\370#>\222\233\314?\334\360?\275|t\303\277\214\351\000\300\3065\302\2775\206\306>X\251\227\277x\2160?U^\251?d\221\350?\237F.?\rp9?9X\004=/c\324\277SL\360\277\'\274<?t\375l?\342\270l?\240\352:>\332\356\226\275\211\035\241>*\271\204\277>\025W>\036K\035?\036\233\200=\035\313\250\276\017\003\346\277\374p_?\313WD?!\006\351\275\232\\q\277\230\007A?" raw_data: "\264\314\344\275\017A\376\276\313\374&>J\266a\277s\306\\=\212\032+?\211[t\275\344[\357\276Dk\\\276OKb?\234\'B\277A\334\274\2767N\257\276\320s\263\277\371+\244>:\314\202\277K\200L??\001\275\275\236u4\2774\032\315\277\214\004\224>Z\320\372>\267B\305\276\346G6\277N\265.\276\343\316\272\277t\364a>\201)|>p\223\251\277Qm2?\346\275)\277\354\235\233?\027X\277\277\253\206a?\354\335\226\277L\032o\277\251J\021\277\311\360\215\276\312\274\013\300\252\320\273>\220\"p?\267\020\000<R\262\240\276\343\016\224\2779\241\353?8;\202\277\023\020\234?E\370#>\222\233\314?\334\360?\275|t\303\277\214\351\000\300\3065\302\2775\206\306>X\251\227\277x\2160?U^\251?d\221\350?\237F.?\rp9?9X\004=/c\324\277SL\360\277\'\274<?t\375l?\342\270l?\240\352:>\332\356\226\275\211\035\241>*\271\204\277>\025W>\036K\035?\036\233\200=\035\313\250\276\017\003\346\277\374p_?\313WD?!\006\351\275\232\\q\277\230\007A?"
} }
initializer {
data_type: 9
name: "onnx::Loop_33"
raw_data: "\001"
}
input { input {
name: "input" name: "input"
type { type {
@ -404,6 +403,16 @@ graph {
} }
} }
} }
input {
name: "onnx::Loop_33"
type {
tensor_type {
elem_type: 9
shape {
}
}
}
}
output { output {
name: "20" name: "20"
type { type {

View File

@ -7,6 +7,7 @@ graph {
input: "weight" input: "weight"
input: "bias" input: "bias"
output: "3" output: "3"
name: "ATen_0"
op_type: "ATen" op_type: "ATen"
attribute { attribute {
name: "cudnn_enable" name: "cudnn_enable"
@ -34,6 +35,7 @@ graph {
s: "" s: ""
type: STRING type: STRING
} }
domain: "org.pytorch.aten"
} }
name: "torch_jit" name: "torch_jit"
initializer { initializer {
@ -130,3 +132,7 @@ graph {
opset_import { opset_import {
version: 13 version: 13
} }
opset_import {
domain: "org.pytorch.aten"
version: 1
}

View File

@ -20,12 +20,15 @@ import os
import shutil import shutil
import tempfile import tempfile
import torch.testing._internal.common_utils as common import torch.testing._internal.common_utils as common
from torch.testing._internal.common_utils import skipIfCaffe2
'''Usage: python test/onnx/test_operators.py [--no-onnx] [--produce-onnx-test-data] '''Usage: python test/onnx/test_operators.py [--no-onnx] [--produce-onnx-test-data]
--no-onnx: no onnx python dependence --no-onnx: no onnx python dependence
--produce-onnx-test-data: generate onnx test data --produce-onnx-test-data: generate onnx test data
--accept: accept onnx updates and overwrite models --accept: accept onnx updates and overwrite models
''' '''
import unittest
unittest.TestCase.maxDiff = None
_onnx_test = False # flag to produce onnx test cases. _onnx_test = False # flag to produce onnx test cases.
_onnx_dep = True # flag to import onnx package. _onnx_dep = True # flag to import onnx package.
@ -322,6 +325,7 @@ class TestOperators(TestCase):
x = torch.randn(20, 16, 50) x = torch.randn(20, 16, 50)
self.assertONNX(nn.MaxPool1d(3, stride=2, return_indices=True), x) self.assertONNX(nn.MaxPool1d(3, stride=2, return_indices=True), x)
@skipIfCaffe2
def test_at_op(self): def test_at_op(self):
x = torch.randn(3, 4) x = torch.randn(3, 4)
@ -339,7 +343,8 @@ class TestOperators(TestCase):
def forward(self, x): def forward(self, x):
return MyFun.apply(x) return MyFun.apply(x)
self.assertONNX(MyModule(), x) self.assertONNX(MyModule(), x,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
def test_clip(self): def test_clip(self):
x = torch.randn(3, 4, requires_grad=True) x = torch.randn(3, 4, requires_grad=True)
@ -588,6 +593,7 @@ class TestOperators(TestCase):
self.assertONNX(nn.BatchNorm2d(128, affine=False, momentum=0.3), x, self.assertONNX(nn.BatchNorm2d(128, affine=False, momentum=0.3), x,
keep_initializers_as_inputs=True) keep_initializers_as_inputs=True)
@skipIfCaffe2
def test_embedding_bags(self): def test_embedding_bags(self):
emb_bag = nn.EmbeddingBag(10, 8) emb_bag = nn.EmbeddingBag(10, 8)
input = torch.tensor([1, 2, 3, 4]).long() input = torch.tensor([1, 2, 3, 4]).long()
@ -787,6 +793,7 @@ class TestOperators(TestCase):
input2 = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2) input2 = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2)
self.assertONNX(BitshiftModel(), (input, input2), opset_version=11) self.assertONNX(BitshiftModel(), (input, input2), opset_version=11)
@skipIfCaffe2
def test_layer_norm_aten(self): def test_layer_norm_aten(self):
model = torch.nn.LayerNorm([10, 10]) model = torch.nn.LayerNorm([10, 10])
x = torch.randn(20, 5, 10, 10) x = torch.randn(20, 5, 10, 10)
@ -954,7 +961,7 @@ class TestOperators(TestCase):
f'"sparse":{str(sparse).lower()}' f'"sparse":{str(sparse).lower()}'
'}' '}'
) )
output = g.op("com.microsoft::ATenOp", weight, indices, name_s='aten::embedding', output = g.at("embedding", weight, indices,
custom_attributes_json_s=custom_attributes_json) custom_attributes_json_s=custom_attributes_json)
return output return output
@ -978,6 +985,7 @@ class TestOperators(TestCase):
unregister_custom_op_symbolic('::embedding', _onnx_opset_version) unregister_custom_op_symbolic('::embedding', _onnx_opset_version)
# This is test_aten_embedding_1 with shape inference on custom symbolic aten::embedding. # This is test_aten_embedding_1 with shape inference on custom symbolic aten::embedding.
@skipIfCaffe2
def test_aten_embedding_2(self): def test_aten_embedding_2(self):
_onnx_opset_version = 12 _onnx_opset_version = 12
@ -990,7 +998,7 @@ class TestOperators(TestCase):
f'"sparse":{str(sparse).lower()}' f'"sparse":{str(sparse).lower()}'
'}' '}'
) )
output = g.op("com.microsoft::ATenOp", weight, indices, name_s='aten::embedding', output = g.at("embedding", weight, indices,
custom_attributes_json_s=custom_attributes_json) custom_attributes_json_s=custom_attributes_json)
# do shape inference and set it via setType # do shape inference and set it via setType
@ -1016,7 +1024,9 @@ class TestOperators(TestCase):
x = torch.ones(32, dtype=torch.long) x = torch.ones(32, dtype=torch.long)
y = torch.randn(1, 8) y = torch.randn(1, 8)
self.assertONNX(model, (x, y), opset_version=_onnx_opset_version, input_names=['input_1', 'input_2'], self.assertONNX(model, (x, y), opset_version=_onnx_opset_version, input_names=['input_1', 'input_2'],
dynamic_axes={"input_1": {0: "dim_0"}, 'input_2': {0: "dim_1", 1: "dim_2"}}) dynamic_axes={"input_1": {0: "dim_0"}, 'input_2': {0: "dim_1", 1: "dim_2"}},
keep_initializers_as_inputs=False,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
unregister_custom_op_symbolic('::embedding', _onnx_opset_version) unregister_custom_op_symbolic('::embedding', _onnx_opset_version)

View File

@ -62,6 +62,8 @@ from torch.testing._internal.common_quantized import (
override_qengines, override_qengines,
) )
from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_utils import skipIfNoCaffe2
from hypothesis import given from hypothesis import given
from hypothesis import strategies as st from hypothesis import strategies as st
import torch.testing._internal.hypothesis_utils as hu import torch.testing._internal.hypothesis_utils as hu
@ -1464,6 +1466,7 @@ class TestQuantizeEagerONNXExport(JitTestCase):
onnx_model = export_to_onnx(model, data, input_names) onnx_model = export_to_onnx(model, data, input_names)
@skipIfNoFBGEMM @skipIfNoFBGEMM
@skipIfNoCaffe2
def test_lower_graph_linear(self): def test_lower_graph_linear(self):
model = torch.ao.quantization.QuantWrapper(torch.nn.Linear(5, 10, bias=True)).to(dtype=torch.float) model = torch.ao.quantization.QuantWrapper(torch.nn.Linear(5, 10, bias=True)).to(dtype=torch.float)
data_numpy = np.random.rand(1, 2, 5).astype(np.float32) data_numpy = np.random.rand(1, 2, 5).astype(np.float32)
@ -1471,6 +1474,7 @@ class TestQuantizeEagerONNXExport(JitTestCase):
self._test_lower_graph_impl(model, data) self._test_lower_graph_impl(model, data)
@skipIfNoFBGEMM @skipIfNoFBGEMM
@skipIfNoCaffe2
def test_lower_graph_conv2d(self): def test_lower_graph_conv2d(self):
model = torch.ao.quantization.QuantWrapper(torch.nn.Conv2d(3, 5, 2, bias=True)).to(dtype=torch.float) model = torch.ao.quantization.QuantWrapper(torch.nn.Conv2d(3, 5, 2, bias=True)).to(dtype=torch.float)
data_numpy = np.random.rand(1, 3, 6, 6).astype(np.float32) data_numpy = np.random.rand(1, 3, 6, 6).astype(np.float32)

View File

@ -362,7 +362,6 @@ if(USE_NUMPY)
target_compile_definitions(torch_python PRIVATE USE_NUMPY) target_compile_definitions(torch_python PRIVATE USE_NUMPY)
endif() endif()
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS BUILD_CAFFE2)
if(HAVE_SOVERSION) if(HAVE_SOVERSION)
set_target_properties(torch_python PROPERTIES set_target_properties(torch_python PROPERTIES
VERSION ${TORCH_VERSION} SOVERSION ${TORCH_SOVERSION}) VERSION ${TORCH_VERSION} SOVERSION ${TORCH_SOVERSION})

View File

@ -1905,7 +1905,8 @@ static std::unordered_set<std::string> nodeTypeReliableForTracer = {
"onnx::Cast", "onnx::Cast",
"onnx::Constant", "onnx::Constant",
"onnx::Relu", "onnx::Relu",
"com.microsoft::Gelu"}; "com.microsoft::Gelu",
"aten::ATen"};
void UpdateReliable( void UpdateReliable(
torch::jit::Value* output, torch::jit::Value* output,

View File

@ -113,7 +113,7 @@ void validateBlock(
WithInsertPoint guard(node); WithInsertPoint guard(node);
auto* new_node = auto* new_node =
b->owningGraph()->insertNode(b->owningGraph()->create( b->owningGraph()->insertNode(b->owningGraph()->create(
Symbol(::c10::onnx::ATen), Symbol(::c10::aten::ATen),
node->inputs(), node->inputs(),
node->outputs().size())); node->outputs().size()));
for (size_t i = 0; i < node->outputs().size(); ++i) { for (size_t i = 0; i < node->outputs().size(); ++i) {
@ -1163,8 +1163,8 @@ void GraphEncoder::EncodeIntermediateValueInfo(
const Value* v) { const Value* v) {
// Motivation is to encode ValueInfo for onnx local function nodes. // Motivation is to encode ValueInfo for onnx local function nodes.
auto n = v->node(); auto n = v->node();
if (n->kind().is_onnx()) { if (n->kind().is_onnx() || n->kind().is_aten()) {
// Encode value info only for non-onnx nodes. // Encode value info only for non-onnx or non-ATen nodes.
return; return;
} }
if (n->owningGraph() != graph_.get()) { if (n->owningGraph() != graph_.get()) {

View File

@ -329,6 +329,10 @@ def _is_scalar_list(x):
element_type in scalar_name_to_pytorch.keys() and \ element_type in scalar_name_to_pytorch.keys() and \
(scalar_name_to_pytorch[element_type] in cast_pytorch_to_onnx.keys()) (scalar_name_to_pytorch[element_type] in cast_pytorch_to_onnx.keys())
def is_caffe2_aten_fallback():
return (_operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK and
torch.onnx._CAFFE2_ATEN_FALLBACK)
def _get_tensor_rank(x): def _get_tensor_rank(x):
if not _is_tensor(x) or x.type() is None: if not _is_tensor(x) or x.type() is None:
return None return None

View File

@ -551,10 +551,7 @@ def arange(g, *args):
def _dim_arange(g, like, dim): def _dim_arange(g, like, dim):
like_shape = g.op("Shape", like) like_shape = g.op("Shape", like)
stop = g.op("Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0) stop = g.op("Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0)
# Caffe2-specific op if sym_help.is_caffe2_aten_fallback():
is_caffe2_aten_fallback = (sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK and
torch.onnx._CAFFE2_ATEN_FALLBACK)
if is_caffe2_aten_fallback:
return g.op("_caffe2::Range", stop) return g.op("_caffe2::Range", stop)
return arange(g, stop, 4, None, None, None) return arange(g, stop, 4, None, None, None)
@ -643,7 +640,8 @@ def index(g, self, index):
def index_fill(g, self, dim, index, value): def index_fill(g, self, dim, index, value):
dim_value = sym_help._parse_arg(dim, "i") dim_value = sym_help._parse_arg(dim, "i")
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
return g.at("index_fill", self, index, value, dim_i=dim_value, overload_name="int_Scalar") return g.at("index_fill", self, index, value, overload_name="int_Scalar", dim_i=dim_value)
expanded_index_shape, expanded_index = sym_help._index_fill_reshape_helper(g, self, dim, index) expanded_index_shape, expanded_index = sym_help._index_fill_reshape_helper(g, self, dim, index)
value = sym_help._maybe_get_scalar(value) value = sym_help._maybe_get_scalar(value)
value = sym_help._if_scalar_type_as(g, value, self) value = sym_help._if_scalar_type_as(g, value, self)

View File

@ -565,7 +565,7 @@ def transpose(g, self, dim0, dim1):
# if we don't have dim information we cannot # if we don't have dim information we cannot
# output a permute so use ATen instead # output a permute so use ATen instead
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
return g.at("transpose", self, dim0_i=dim0, dim1_i=dim1, overload_name="int") return g.at("transpose", self, overload_name="int", dim0_i=dim0, dim1_i=dim1)
else: else:
raise RuntimeError("Unsupported: ONNX export of transpose for tensor " raise RuntimeError("Unsupported: ONNX export of transpose for tensor "
"of unknown rank.") "of unknown rank.")
@ -1581,7 +1581,8 @@ def index_put(g, self, indices_list_value, values, accumulate):
def index_fill(g, self, dim, index, value): def index_fill(g, self, dim, index, value):
dim_value = sym_help._parse_arg(dim, "i") dim_value = sym_help._parse_arg(dim, "i")
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
return g.at("index_fill", self, index, value, dim_i=dim_value, overload_name="int_Scalar") return g.at("index_fill", self, index, value, overload_name="int_Scalar", dim_i=dim_value)
expanded_index_shape, expanded_index = sym_help._index_fill_reshape_helper(g, self, dim, index) expanded_index_shape, expanded_index = sym_help._index_fill_reshape_helper(g, self, dim, index)
value = sym_help._maybe_get_scalar(value) value = sym_help._maybe_get_scalar(value)
value = sym_help._if_scalar_type_as(g, value, self) value = sym_help._if_scalar_type_as(g, value, self)
@ -1647,9 +1648,7 @@ def type_as(g, self, other):
@parse_args("v", "v", "i", "f") @parse_args("v", "v", "i", "f")
def cosine_similarity(g, x1, x2, dim, eps): def cosine_similarity(g, x1, x2, dim, eps):
# preserve legacy behavior for Caffe2 if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK and \
torch.onnx._CAFFE2_ATEN_FALLBACK:
return g.at("cosine_similarity", x1, x2, dim_i=dim, eps_f=eps) return g.at("cosine_similarity", x1, x2, dim_i=dim, eps_f=eps)
cross = sym_help._reducesum_helper(g, mul(g, x1, x2), cross = sym_help._reducesum_helper(g, mul(g, x1, x2),
axes_i=[dim], keepdims_i=0) axes_i=[dim], keepdims_i=0)
@ -2599,10 +2598,7 @@ rnn_relu = _one_hidden_rnn("RNN_RELU")
def _dim_arange(g, like, dim): def _dim_arange(g, like, dim):
like_shape = g.op("Shape", like) like_shape = g.op("Shape", like)
stop = g.op("Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0) stop = g.op("Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0)
# Caffe2-specific op if sym_help.is_caffe2_aten_fallback():
is_caffe2_aten_fallback = (sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK and
torch.onnx._CAFFE2_ATEN_FALLBACK)
if is_caffe2_aten_fallback:
return g.op("_caffe2::Range", stop) return g.op("_caffe2::Range", stop)
else: else:
# aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)

View File

@ -190,7 +190,7 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa
torch._C._jit_pass_peephole(graph, True) torch._C._jit_pass_peephole(graph, True)
torch._C._jit_pass_fuse_addmm(graph) torch._C._jit_pass_fuse_addmm(graph)
torch._C._jit_pass_lint(graph) torch._C._jit_pass_lint(graph)
from torch.onnx.symbolic_helper import _onnx_shape_inference, _export_onnx_opset_version from torch.onnx.symbolic_helper import _onnx_shape_inference, _export_onnx_opset_version, is_caffe2_aten_fallback
torch._C._jit_pass_peephole(graph, True) torch._C._jit_pass_peephole(graph, True)
torch._C._jit_pass_lower_all_tuples(graph) torch._C._jit_pass_lower_all_tuples(graph)
@ -212,13 +212,10 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa
torch._C._jit_pass_onnx_remove_print(graph) torch._C._jit_pass_onnx_remove_print(graph)
torch._C._jit_pass_onnx_preprocess_caffe2(graph) torch._C._jit_pass_onnx_preprocess_caffe2(graph)
# Caffe2-specific optimization
is_caffe2_aten_fallback = (operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK and
torch.onnx._CAFFE2_ATEN_FALLBACK)
torch.onnx.symbolic_helper._quantized_ops.clear() torch.onnx.symbolic_helper._quantized_ops.clear()
# Unpack quantized weights for conv and linear ops and insert into graph. # Unpack quantized weights for conv and linear ops and insert into graph.
torch._C._jit_pass_onnx_unpack_quantized_weights(graph, params_dict, is_caffe2_aten_fallback) torch._C._jit_pass_onnx_unpack_quantized_weights(graph, params_dict, is_caffe2_aten_fallback())
if is_caffe2_aten_fallback: if is_caffe2_aten_fallback():
# Insert permutes before and after each conv op to ensure correct order. # Insert permutes before and after each conv op to ensure correct order.
torch._C._jit_pass_onnx_quantization_insert_permutes(graph, params_dict) torch._C._jit_pass_onnx_quantization_insert_permutes(graph, params_dict)
@ -289,7 +286,7 @@ def warn_on_static_input_change(input_states):
def _resolve_args_by_export_type(arg_name, arg_value, operator_export_type): def _resolve_args_by_export_type(arg_name, arg_value, operator_export_type):
# This helper method resolves the arguments that are ignored when export_type != operator_export_type.ONNX # This helper method resolves the arguments that are ignored when export_type != operator_export_type.ONNX
if operator_export_type is not operator_export_type.ONNX: if operator_export_type is not operator_export_type.ONNX and torch.onnx._CAFFE2_ATEN_FALLBACK:
if arg_value is True: if arg_value is True:
warnings.warn("`{}' can be set to True only when 'operator_export_type' is " warnings.warn("`{}' can be set to True only when 'operator_export_type' is "
"`ONNX`. Since 'operator_export_type' is not set to 'ONNX', " "`ONNX`. Since 'operator_export_type' is not set to 'ONNX', "
@ -959,7 +956,8 @@ def _add_attribute(node, key, value, aten):
name, kind = m.group(1), m.group(2) name, kind = m.group(1), m.group(2)
if _is_onnx_list(value): if _is_onnx_list(value):
kind += "s" kind += "s"
if aten: from torch.onnx.symbolic_helper import is_caffe2_aten_fallback
if aten and is_caffe2_aten_fallback():
if isinstance(value, torch.Tensor): if isinstance(value, torch.Tensor):
# Caffe2 proto does not support tensor attribute. # Caffe2 proto does not support tensor attribute.
if value.numel() > 1: if value.numel() > 1:
@ -1119,14 +1117,13 @@ def _run_symbolic_function(g, block, n, inputs, env, operator_export_type=Operat
try: try:
import torch import torch
from torch.onnx.symbolic_helper import _export_onnx_opset_version as opset_version from torch.onnx.symbolic_helper import _export_onnx_opset_version as opset_version
from torch.onnx.symbolic_helper import is_caffe2_aten_fallback
import torch.onnx.symbolic_registry as sym_registry import torch.onnx.symbolic_registry as sym_registry
sym_registry.register_version("", opset_version) sym_registry.register_version("", opset_version)
# Caffe2-specific: Quantized op symbolics are registered for opset 9 only. # Caffe2-specific: Quantized op symbolics are registered for opset 9 only.
is_caffe2_aten_fallback = (operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK and if is_caffe2_aten_fallback() and opset_version == 9:
torch.onnx._CAFFE2_ATEN_FALLBACK)
if is_caffe2_aten_fallback and opset_version == 9:
import torch.onnx.symbolic_caffe2 import torch.onnx.symbolic_caffe2
torch.onnx.symbolic_caffe2.register_quantized_ops("caffe2", opset_version) torch.onnx.symbolic_caffe2.register_quantized_ops("caffe2", opset_version)
@ -1137,11 +1134,10 @@ def _run_symbolic_function(g, block, n, inputs, env, operator_export_type=Operat
else: else:
ns_op_name = n.kind() ns_op_name = n.kind()
ns, op_name = ns_op_name.split("::") ns, op_name = ns_op_name.split("::")
domain = ns domain = ns
if ns == "aten": if ns == "aten":
domain = "" domain = ""
elif ns == "quantized" and is_caffe2_aten_fallback: elif ns == "quantized" and is_caffe2_aten_fallback():
domain = "caffe2" domain = "caffe2"
if sym_registry.is_registered_op(op_name, domain, opset_version): if sym_registry.is_registered_op(op_name, domain, opset_version):
@ -1165,7 +1161,7 @@ def _run_symbolic_function(g, block, n, inputs, env, operator_export_type=Operat
attrs = {k + "_" + n.kindOf(k)[0]: n[k] for k in n.attributeNames()} attrs = {k + "_" + n.kindOf(k)[0]: n[k] for k in n.attributeNames()}
outputs = n.outputsSize() outputs = n.outputsSize()
attrs["outputs"] = outputs attrs["outputs"] = outputs
return g.at(op_name, *inputs, aten=True, **attrs) return g.at(op_name, *inputs, **attrs)
else: else:
raise sym_registry.UnsupportedOperatorError(domain, op_name, opset_version) raise sym_registry.UnsupportedOperatorError(domain, op_name, opset_version)
except RuntimeError: except RuntimeError:
@ -1181,6 +1177,7 @@ def _run_symbolic_function(g, block, n, inputs, env, operator_export_type=Operat
# Generate an ONNX ATen op node. # Generate an ONNX ATen op node.
def _aten_op(g, operator, *args, overload_name="", **kwargs): def _aten_op(g, operator, *args, overload_name="", **kwargs):
kwargs["aten"] = True
return g.op("ATen", *args, operator_s=operator, overload_name_s=overload_name, **kwargs) return g.op("ATen", *args, operator_s=operator, overload_name_s=overload_name, **kwargs)
@ -1315,7 +1312,7 @@ def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names):
torch._C.Graph.op = _graph_op # type: ignore[attr-defined] torch._C.Graph.op = _graph_op # type: ignore[attr-defined]
torch._C.Graph.at = _aten_op # type: ignore[attr-defined] torch._C.Graph.at = _aten_op # type: ignore[attr-defined]
torch._C.Block.op = _block_op # type: ignore[attr-defined] torch._C.Block.op = _block_op # type: ignore[attr-defined]
torch._C.Graph.constant = _graph_constant # type: ignore[attr-defined] torch._C.Graph.constant = _graph_constant # type: ignore[attr-defined]
torch._C.Node.__getitem__ = _node_getitem # type: ignore[attr-defined, misc, assignment] torch._C.Node.__getitem__ = _node_getitem # type: ignore[attr-defined, misc, assignment]

View File

@ -295,6 +295,14 @@ def skipIfNoQNNPACK(fn):
fn(*args, **kwargs) fn(*args, **kwargs)
return wrapper return wrapper
@functools.wraps(fn)
def wrapper(*args, **kwargs):
if not torch.onnx._CAFFE2_ATEN_FALLBACK:
raise unittest.SkipTest(reason)
else:
fn(*args, **kwargs)
return wrapper
try: try:
import torchvision # noqa: F401 import torchvision # noqa: F401
HAS_TORCHVISION = True HAS_TORCHVISION = True

View File

@ -1023,7 +1023,6 @@ def skipIfNoLapack(fn):
fn(*args, **kwargs) fn(*args, **kwargs)
return wrapper return wrapper
def skipIfNotRegistered(op_name, message): def skipIfNotRegistered(op_name, message):
"""Wraps the decorator to hide the import of the `core`. """Wraps the decorator to hide the import of the `core`.
@ -1045,6 +1044,17 @@ def skipIfNotRegistered(op_name, message):
skipper = unittest.skip("Cannot import `caffe2.python.core`") skipper = unittest.skip("Cannot import `caffe2.python.core`")
return skipper return skipper
def _decide_skip_caffe2(expect_caffe2, reason):
def skip_dec(func):
def wrapper(self):
if torch.onnx._CAFFE2_ATEN_FALLBACK != expect_caffe2:
raise unittest.SkipTest(reason)
return func(self)
return wrapper
return skip_dec
skipIfCaffe2 = _decide_skip_caffe2(False, "Not compatible with Caffe2")
skipIfNoCaffe2 = _decide_skip_caffe2(True, "Caffe2 is not available")
def skipIfNoSciPy(fn): def skipIfNoSciPy(fn):
@wraps(fn) @wraps(fn)