diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index 46b43aecce28..5e9858cf3401 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -247,7 +247,7 @@ namespace c10 { _(onnx, Less) \ _(onnx, LessOrEqual) \ _(onnx, Not) \ - _(onnx, ATen) \ + _(aten, ATen) \ _(onnx, Split) \ _(onnx, ConstantOfShape) \ _(onnx, Cast) \ @@ -316,7 +316,8 @@ namespace c10 { _(attr, new_axis) \ _(attr, warn_id) \ _(attr, allowzero) \ - _(attr, seen_none) + _(attr, seen_none) \ + _(attr, overload_name) enum class _keys : unique_t { #define DEFINE_KEY(ns, s) ns##_##s, diff --git a/binaries/bench_gen/bench_gen.py b/binaries/bench_gen/bench_gen.py index 2b344c1f5947..8684e07ee4fd 100755 --- a/binaries/bench_gen/bench_gen.py +++ b/binaries/bench_gen/bench_gen.py @@ -59,7 +59,7 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser( - description="Utilitity to generate Caffe2 benchmark models.") + description="Utility to generate Caffe2 benchmark models.") parser.add_argument("operator", help="Caffe2 operator to benchmark.") parser.add_argument("-b", "--blob", help="Instantiate a blob --blob name=dim1,dim2,dim3", diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 27e75803e236..5045432e7130 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1943,6 +1943,8 @@ if(BUILD_PYTHON) # ---[ Python. if(BUILD_CAFFE2) add_library(caffe2_pybind11_state MODULE ${Caffe2_CPU_PYTHON_SRCS}) + target_compile_definitions(torch PRIVATE BUILD_CAFFE2) + target_compile_definitions(torch_python PRIVATE BUILD_CAFFE2) if(USE_NUMPY) target_compile_options(caffe2_pybind11_state PRIVATE "-DUSE_NUMPY") target_link_libraries(caffe2_pybind11_state PRIVATE numpy::numpy) diff --git a/caffe2/contrib/aten/README.md b/caffe2/contrib/aten/README.md index 593079ef1393..79a4276a65f8 100644 --- a/caffe2/contrib/aten/README.md +++ b/caffe2/contrib/aten/README.md @@ -72,7 +72,7 @@ class Add(torch.autograd.Function): @staticmethod def symbolic(g, a, b): - return g.op("ATen", a, b, operator_s = "add") + return g.at("add", a, b) @staticmethod def forward(ctx, a, b): diff --git a/caffe2/contrib/aten/aten_op_template.h b/caffe2/contrib/aten/aten_op_template.h index 97c64631921a..b22b840c25ad 100644 --- a/caffe2/contrib/aten/aten_op_template.h +++ b/caffe2/contrib/aten/aten_op_template.h @@ -179,7 +179,7 @@ private: std::vector attrs; for (const auto i : c10::irange(operator_def.arg_size())) { auto & attr = operator_def.arg(i); - if(attr.name() == "operator" || attr.name() == "type" || attr.name() == "overload_name" ) { + if (attr.name() == "operator" || attr.name() == "type" || attr.name() == "overload_name") { continue; } attrs.push_back(attr.name()); diff --git a/caffe2/contrib/aten/aten_test.py b/caffe2/contrib/aten/aten_test.py index 4a025c3b1802..6574884245f8 100644 --- a/caffe2/contrib/aten/aten_test.py +++ b/caffe2/contrib/aten/aten_test.py @@ -1,9 +1,4 @@ - - - - - -from caffe2.python import core, dyndep +from caffe2.python import core from hypothesis import given import caffe2.python.hypothesis_test_util as hu diff --git a/caffe2/contrib/aten/docs/sample.py b/caffe2/contrib/aten/docs/sample.py index 53ce19b86e89..6896f2379d8c 100644 --- a/caffe2/contrib/aten/docs/sample.py +++ b/caffe2/contrib/aten/docs/sample.py @@ -38,8 +38,8 @@ torch.onnx.export(MyModule(), # graph(%input : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu), # %y : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)): # %2 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = onnx::Relu(%input) -# %3 : Tensor = onnx::ATen[operator="mul"](%2, %2) -# %4 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = onnx::ATen[operator="add"](%3, %y) +# %3 : Tensor = aten::ATen[operator="mul"](%2, %2) +# %4 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::ATen[operator="add"](%3, %y) # return (%4) graph = onnx.load(f.name) diff --git a/caffe2/python/benchmark_generator.py b/caffe2/python/benchmark_generator.py index 5342cb314a5b..c557ebfc9536 100644 --- a/caffe2/python/benchmark_generator.py +++ b/caffe2/python/benchmark_generator.py @@ -106,7 +106,7 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser( - description="Utilitity to generate Caffe2 benchmark models.") + description="Utility to generate Caffe2 benchmark models.") parser.add_argument("operator", help="Caffe2 operator to benchmark.") parser.add_argument("-b", "--blob", help="Instantiate a blob --blob name=dim1,dim2,dim3", diff --git a/test/expect/TestPytorchExportModes.test_aten_fallback.expect b/test/expect/TestPytorchExportModes.test_aten_fallback.expect index d5cfb31cfeef..83c481fd7e9b 100644 --- a/test/expect/TestPytorchExportModes.test_aten_fallback.expect +++ b/test/expect/TestPytorchExportModes.test_aten_fallback.expect @@ -11,7 +11,7 @@ ModelProto { nodes: [ Node {type: "Add", inputs: [0,1], outputs: [2], attributes: []}, Node {type: "Constant", inputs: [], outputs: [3], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]}, - Node {type: "ATen", inputs: [2,3], outputs: [4,5], attributes: [{ name: 'operator', type: string, value: 'qr'}, { name: 'overload_name', type: string, value: ''}]} + Node {type: "ATen", domain: "org.pytorch.aten", inputs: [2,3], outputs: [4,5], attributes: [{ name: 'operator', type: string, value: 'qr'}, { name: 'overload_name', type: string, value: ''}]} ] } opset_import: [OperatorSetIdProto { domain: }OperatorSetIdProto { domain: org.pytorch.aten}], diff --git a/test/expect/TestPytorchExportModes.test_onnx_aten.expect b/test/expect/TestPytorchExportModes.test_onnx_aten.expect index 85f4f8573d1c..3c2960f91f96 100644 --- a/test/expect/TestPytorchExportModes.test_onnx_aten.expect +++ b/test/expect/TestPytorchExportModes.test_onnx_aten.expect @@ -9,7 +9,7 @@ ModelProto { outputs: [{name: "2", type:Tensor dims: 3 4}] initializers: [] nodes: [ - Node {type: "ATen", inputs: [0,1], outputs: [2], attributes: [{ name: 'operator', type: string, value: 'fmod'}, { name: 'overload_name', type: string, value: ''}]} + Node {type: "ATen", domain: "org.pytorch.aten", inputs: [0,1], outputs: [2], attributes: [{ name: 'operator', type: string, value: 'fmod'}, { name: 'overload_name', type: string, value: ''}]} ] } opset_import: [OperatorSetIdProto { domain: }OperatorSetIdProto { domain: org.pytorch.aten}], diff --git a/test/expect/TestScript.test_listconstruct_erasure.expect b/test/expect/TestScript.test_listconstruct_erasure.expect index 7d4bb8d97fc0..8172b3fe0c76 100644 --- a/test/expect/TestScript.test_listconstruct_erasure.expect +++ b/test/expect/TestScript.test_listconstruct_erasure.expect @@ -13,7 +13,7 @@ ModelProto { Node {type: "Less", inputs: [0,1], outputs: [2], attributes: []}, Node {type: "Cast", inputs: [2], outputs: [3], attributes: [{ name: 'to', type: int, value: 2}]}, Node {type: "Cast", inputs: [3], outputs: [4], attributes: [{ name: 'to', type: int, value: 9}]}, - Node {type: "ATen", inputs: [0,4], outputs: [5], attributes: [{ name: 'operator', type: string, value: 'index'}, { name: 'overload_name', type: string, value: ''}]} + Node {type: "ATen", domain: "org.pytorch.aten", inputs: [0,4], outputs: [5], attributes: [{ name: 'operator', type: string, value: 'index'}, { name: 'overload_name', type: string, value: ''}]} ] } opset_import: [OperatorSetIdProto { domain: }OperatorSetIdProto { domain: org.pytorch.aten}], diff --git a/test/onnx/expect/TestOperators.test_at_op.expect b/test/onnx/expect/TestOperators.test_at_op.expect index 8d4ba07ddcc8..8890f6535756 100644 --- a/test/onnx/expect/TestOperators.test_at_op.expect +++ b/test/onnx/expect/TestOperators.test_at_op.expect @@ -18,6 +18,7 @@ graph { s: "" type: STRING } + domain: "org.pytorch.aten" } name: "torch_jit" input { @@ -56,3 +57,7 @@ graph { opset_import { version: 13 } +opset_import { + domain: "org.pytorch.aten" + version: 1 +} diff --git a/test/onnx/expect/TestOperators.test_aten_embedding_2.expect b/test/onnx/expect/TestOperators.test_aten_embedding_2.expect index 20c7b94bb7e3..98779b99d98d 100644 --- a/test/onnx/expect/TestOperators.test_aten_embedding_2.expect +++ b/test/onnx/expect/TestOperators.test_aten_embedding_2.expect @@ -6,19 +6,24 @@ graph { input: "emb.weight" input: "input_1" output: "onnx::Add_3" - name: "ATenOp_0" - op_type: "ATenOp" + name: "ATen_0" + op_type: "ATen" attribute { name: "custom_attributes_json" s: "{\"padding_idx\":-1,\"scale_grad_by_freq\":false,\"sparse\":false}" type: STRING } attribute { - name: "name" - s: "aten::embedding" + name: "operator" + s: "embedding" type: STRING } - domain: "com.microsoft" + attribute { + name: "overload_name" + s: "" + type: STRING + } + domain: "org.pytorch.aten" } node { input: "onnx::Add_3" @@ -145,27 +150,11 @@ graph { } } } - value_info { - name: "onnx::Add_3" - type { - tensor_type { - elem_type: 1 - shape { - dim { - dim_param: "ATenOponnx::Add_3_dim_0" - } - dim { - dim_param: "ATenOponnx::Add_3_dim_1" - } - } - } - } - } } opset_import { version: 12 } opset_import { - domain: "com.microsoft" + domain: "org.pytorch.aten" version: 1 } diff --git a/test/onnx/expect/TestOperators.test_embedding_bags.expect b/test/onnx/expect/TestOperators.test_embedding_bags.expect index dfa1afddee30..eb4a94b75590 100644 --- a/test/onnx/expect/TestOperators.test_embedding_bags.expect +++ b/test/onnx/expect/TestOperators.test_embedding_bags.expect @@ -2,30 +2,9 @@ ir_version: 7 producer_name: "pytorch" producer_version: "CURRENT_VERSION" graph { - node { - output: "onnx::Cast_3" - op_type: "Constant" - attribute { - name: "value" - t { - data_type: 7 - raw_data: "\001\000\000\000\000\000\000\000" - } - type: TENSOR - } - } - node { - input: "onnx::Cast_3" - output: "onnx::Loop_4" - op_type: "Cast" - attribute { - name: "to" - i: 9 - type: INT - } - } node { output: "5" + name: "Constant_0" op_type: "Constant" attribute { name: "value" @@ -40,10 +19,12 @@ graph { node { input: "input" output: "onnx::Gather_6" + name: "Shape_1" op_type: "Shape" } node { output: "onnx::Gather_7" + name: "Constant_2" op_type: "Constant" attribute { name: "value" @@ -58,6 +39,7 @@ graph { input: "onnx::Gather_6" input: "onnx::Gather_7" output: "onnx::Unsqueeze_8" + name: "Gather_3" op_type: "Gather" attribute { name: "axis" @@ -67,6 +49,7 @@ graph { } node { output: "onnx::Unsqueeze_9" + name: "Constant_4" op_type: "Constant" attribute { name: "value" @@ -82,12 +65,14 @@ graph { input: "onnx::Unsqueeze_8" input: "onnx::Unsqueeze_9" output: "onnx::Concat_10" + name: "Unsqueeze_5" op_type: "Unsqueeze" } node { input: "offsets" input: "onnx::Concat_10" output: "onnx::Slice_11" + name: "Concat_6" op_type: "Concat" attribute { name: "axis" @@ -97,6 +82,7 @@ graph { } node { output: "onnx::Slice_12" + name: "Constant_7" op_type: "Constant" attribute { name: "value" @@ -110,6 +96,7 @@ graph { } node { output: "onnx::Slice_13" + name: "Constant_8" op_type: "Constant" attribute { name: "value" @@ -123,6 +110,7 @@ graph { } node { output: "onnx::Slice_14" + name: "Constant_9" op_type: "Constant" attribute { name: "value" @@ -136,6 +124,7 @@ graph { } node { output: "onnx::Slice_15" + name: "Constant_10" op_type: "Constant" attribute { name: "value" @@ -154,15 +143,18 @@ graph { input: "onnx::Slice_12" input: "onnx::Slice_15" output: "onnx::Shape_16" + name: "Slice_11" op_type: "Slice" } node { input: "onnx::Shape_16" output: "onnx::Gather_17" + name: "Shape_12" op_type: "Shape" } node { output: "onnx::Gather_18" + name: "Constant_13" op_type: "Constant" attribute { name: "value" @@ -177,6 +169,7 @@ graph { input: "onnx::Gather_17" input: "onnx::Gather_18" output: "onnx::Loop_19" + name: "Gather_14" op_type: "Gather" attribute { name: "axis" @@ -186,8 +179,9 @@ graph { } node { input: "onnx::Loop_19" - input: "onnx::Loop_4" + input: "onnx::Loop_33" output: "20" + name: "Loop_15" op_type: "Loop" attribute { name: "body" @@ -196,7 +190,7 @@ graph { input: "onnx::Slice_11" input: "21" output: "23" - name: "Gather_0" + name: "Gather_16" op_type: "Gather" attribute { name: "axis" @@ -208,7 +202,7 @@ graph { input: "onnx::Shape_16" input: "21" output: "24" - name: "Gather_1" + name: "Gather_17" op_type: "Gather" attribute { name: "axis" @@ -218,7 +212,7 @@ graph { } node { output: "25" - name: "Constant_2" + name: "Constant_18" op_type: "Constant" attribute { name: "value" @@ -234,12 +228,12 @@ graph { input: "23" input: "25" output: "26" - name: "Unsqueeze_3" + name: "Unsqueeze_19" op_type: "Unsqueeze" } node { output: "27" - name: "Constant_4" + name: "Constant_20" op_type: "Constant" attribute { name: "value" @@ -255,7 +249,7 @@ graph { input: "24" input: "27" output: "28" - name: "Unsqueeze_5" + name: "Unsqueeze_21" op_type: "Unsqueeze" } node { @@ -264,14 +258,14 @@ graph { input: "28" input: "5" output: "29" - name: "Slice_6" + name: "Slice_22" op_type: "Slice" } node { input: "weight" input: "29" output: "30" - name: "Gather_7" + name: "Gather_23" op_type: "Gather" attribute { name: "axis" @@ -282,7 +276,7 @@ graph { node { input: "30" output: "31" - name: "ReduceMean_8" + name: "ReduceMean_24" op_type: "ReduceMean" attribute { name: "axes" @@ -296,9 +290,9 @@ graph { } } node { - input: "onnx::Loop_4" + input: "onnx::Loop_33" output: "32" - name: "Cast_9" + name: "Cast_25" op_type: "Cast" attribute { name: "to" @@ -362,6 +356,11 @@ graph { name: "weight" raw_data: "\264\314\344\275\017A\376\276\313\374&>J\266a\277s\306\\=\212\032+?\211[t\275\344[\357\276Dk\\\276OKb?\234\'B\277A\334\274\2767N\257\276\320s\263\277\371+\244>:\314\202\277K\200L??\001\275\275\236u4\2774\032\315\277\214\004\224>Z\320\372>\267B\305\276\346G6\277N\265.\276\343\316\272\277t\364a>\201)|>p\223\251\277Qm2?\346\275)\277\354\235\233?\027X\277\277\253\206a?\354\335\226\277L\032o\277\251J\021\277\311\360\215\276\312\274\013\300\252\320\273>\220\"p?\267\020\000\222\233\314?\334\360?\275|t\303\277\214\351\000\300\3065\302\2775\206\306>X\251\227\277x\2160?U^\251?d\221\350?\237F.?\rp9?9X\004=/c\324\277SL\360\277\'\274\332\356\226\275\211\035\241>*\271\204\277>\025W>\036K\035?\036\233\200=\035\313\250\276\017\003\346\277\374p_?\313WD?!\006\351\275\232\\q\277\230\007A?" } + initializer { + data_type: 9 + name: "onnx::Loop_33" + raw_data: "\001" + } input { name: "input" type { @@ -404,6 +403,16 @@ graph { } } } + input { + name: "onnx::Loop_33" + type { + tensor_type { + elem_type: 9 + shape { + } + } + } + } output { name: "20" type { diff --git a/test/onnx/expect/TestOperators.test_layer_norm_aten.expect b/test/onnx/expect/TestOperators.test_layer_norm_aten.expect index d7b7ac561130..94dbc3582be9 100644 --- a/test/onnx/expect/TestOperators.test_layer_norm_aten.expect +++ b/test/onnx/expect/TestOperators.test_layer_norm_aten.expect @@ -7,6 +7,7 @@ graph { input: "weight" input: "bias" output: "3" + name: "ATen_0" op_type: "ATen" attribute { name: "cudnn_enable" @@ -34,6 +35,7 @@ graph { s: "" type: STRING } + domain: "org.pytorch.aten" } name: "torch_jit" initializer { @@ -130,3 +132,7 @@ graph { opset_import { version: 13 } +opset_import { + domain: "org.pytorch.aten" + version: 1 +} diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py index ca69f0fb0306..505df2028ea6 100644 --- a/test/onnx/test_operators.py +++ b/test/onnx/test_operators.py @@ -20,12 +20,15 @@ import os import shutil import tempfile import torch.testing._internal.common_utils as common +from torch.testing._internal.common_utils import skipIfCaffe2 '''Usage: python test/onnx/test_operators.py [--no-onnx] [--produce-onnx-test-data] --no-onnx: no onnx python dependence --produce-onnx-test-data: generate onnx test data --accept: accept onnx updates and overwrite models ''' +import unittest +unittest.TestCase.maxDiff = None _onnx_test = False # flag to produce onnx test cases. _onnx_dep = True # flag to import onnx package. @@ -322,6 +325,7 @@ class TestOperators(TestCase): x = torch.randn(20, 16, 50) self.assertONNX(nn.MaxPool1d(3, stride=2, return_indices=True), x) + @skipIfCaffe2 def test_at_op(self): x = torch.randn(3, 4) @@ -339,7 +343,8 @@ class TestOperators(TestCase): def forward(self, x): return MyFun.apply(x) - self.assertONNX(MyModule(), x) + self.assertONNX(MyModule(), x, + operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK) def test_clip(self): x = torch.randn(3, 4, requires_grad=True) @@ -588,6 +593,7 @@ class TestOperators(TestCase): self.assertONNX(nn.BatchNorm2d(128, affine=False, momentum=0.3), x, keep_initializers_as_inputs=True) + @skipIfCaffe2 def test_embedding_bags(self): emb_bag = nn.EmbeddingBag(10, 8) input = torch.tensor([1, 2, 3, 4]).long() @@ -787,6 +793,7 @@ class TestOperators(TestCase): input2 = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2) self.assertONNX(BitshiftModel(), (input, input2), opset_version=11) + @skipIfCaffe2 def test_layer_norm_aten(self): model = torch.nn.LayerNorm([10, 10]) x = torch.randn(20, 5, 10, 10) @@ -954,7 +961,7 @@ class TestOperators(TestCase): f'"sparse":{str(sparse).lower()}' '}' ) - output = g.op("com.microsoft::ATenOp", weight, indices, name_s='aten::embedding', + output = g.at("embedding", weight, indices, custom_attributes_json_s=custom_attributes_json) return output @@ -978,6 +985,7 @@ class TestOperators(TestCase): unregister_custom_op_symbolic('::embedding', _onnx_opset_version) # This is test_aten_embedding_1 with shape inference on custom symbolic aten::embedding. + @skipIfCaffe2 def test_aten_embedding_2(self): _onnx_opset_version = 12 @@ -990,7 +998,7 @@ class TestOperators(TestCase): f'"sparse":{str(sparse).lower()}' '}' ) - output = g.op("com.microsoft::ATenOp", weight, indices, name_s='aten::embedding', + output = g.at("embedding", weight, indices, custom_attributes_json_s=custom_attributes_json) # do shape inference and set it via setType @@ -1016,7 +1024,9 @@ class TestOperators(TestCase): x = torch.ones(32, dtype=torch.long) y = torch.randn(1, 8) self.assertONNX(model, (x, y), opset_version=_onnx_opset_version, input_names=['input_1', 'input_2'], - dynamic_axes={"input_1": {0: "dim_0"}, 'input_2': {0: "dim_1", 1: "dim_2"}}) + dynamic_axes={"input_1": {0: "dim_0"}, 'input_2': {0: "dim_1", 1: "dim_2"}}, + keep_initializers_as_inputs=False, + operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK) unregister_custom_op_symbolic('::embedding', _onnx_opset_version) diff --git a/test/quantization/eager/test_quantize_eager_ptq.py b/test/quantization/eager/test_quantize_eager_ptq.py index ec287cd89fa1..d06575c51bf2 100644 --- a/test/quantization/eager/test_quantize_eager_ptq.py +++ b/test/quantization/eager/test_quantize_eager_ptq.py @@ -62,6 +62,8 @@ from torch.testing._internal.common_quantized import ( override_qengines, ) from torch.testing._internal.jit_utils import JitTestCase +from torch.testing._internal.common_utils import skipIfNoCaffe2 + from hypothesis import given from hypothesis import strategies as st import torch.testing._internal.hypothesis_utils as hu @@ -1464,6 +1466,7 @@ class TestQuantizeEagerONNXExport(JitTestCase): onnx_model = export_to_onnx(model, data, input_names) @skipIfNoFBGEMM + @skipIfNoCaffe2 def test_lower_graph_linear(self): model = torch.ao.quantization.QuantWrapper(torch.nn.Linear(5, 10, bias=True)).to(dtype=torch.float) data_numpy = np.random.rand(1, 2, 5).astype(np.float32) @@ -1471,6 +1474,7 @@ class TestQuantizeEagerONNXExport(JitTestCase): self._test_lower_graph_impl(model, data) @skipIfNoFBGEMM + @skipIfNoCaffe2 def test_lower_graph_conv2d(self): model = torch.ao.quantization.QuantWrapper(torch.nn.Conv2d(3, 5, 2, bias=True)).to(dtype=torch.float) data_numpy = np.random.rand(1, 3, 6, 6).astype(np.float32) diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 4dddf7b33d71..cb122c35695c 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -362,7 +362,6 @@ if(USE_NUMPY) target_compile_definitions(torch_python PRIVATE USE_NUMPY) endif() -list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS BUILD_CAFFE2) if(HAVE_SOVERSION) set_target_properties(torch_python PROPERTIES VERSION ${TORCH_VERSION} SOVERSION ${TORCH_SOVERSION}) diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp index bc1d7449a880..fb9ed28fdcf2 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp +++ b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp @@ -1905,7 +1905,8 @@ static std::unordered_set nodeTypeReliableForTracer = { "onnx::Cast", "onnx::Constant", "onnx::Relu", - "com.microsoft::Gelu"}; + "com.microsoft::Gelu", + "aten::ATen"}; void UpdateReliable( torch::jit::Value* output, diff --git a/torch/csrc/jit/serialization/export.cpp b/torch/csrc/jit/serialization/export.cpp index 6e3da2f74626..a924811b8593 100644 --- a/torch/csrc/jit/serialization/export.cpp +++ b/torch/csrc/jit/serialization/export.cpp @@ -113,7 +113,7 @@ void validateBlock( WithInsertPoint guard(node); auto* new_node = b->owningGraph()->insertNode(b->owningGraph()->create( - Symbol(::c10::onnx::ATen), + Symbol(::c10::aten::ATen), node->inputs(), node->outputs().size())); for (size_t i = 0; i < node->outputs().size(); ++i) { @@ -1163,8 +1163,8 @@ void GraphEncoder::EncodeIntermediateValueInfo( const Value* v) { // Motivation is to encode ValueInfo for onnx local function nodes. auto n = v->node(); - if (n->kind().is_onnx()) { - // Encode value info only for non-onnx nodes. + if (n->kind().is_onnx() || n->kind().is_aten()) { + // Encode value info only for non-onnx or non-ATen nodes. return; } if (n->owningGraph() != graph_.get()) { diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py index 4de7876eec68..321bdc0ed5b9 100644 --- a/torch/onnx/symbolic_helper.py +++ b/torch/onnx/symbolic_helper.py @@ -329,6 +329,10 @@ def _is_scalar_list(x): element_type in scalar_name_to_pytorch.keys() and \ (scalar_name_to_pytorch[element_type] in cast_pytorch_to_onnx.keys()) +def is_caffe2_aten_fallback(): + return (_operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK and + torch.onnx._CAFFE2_ATEN_FALLBACK) + def _get_tensor_rank(x): if not _is_tensor(x) or x.type() is None: return None diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py index c5765b3435bb..51cce345bbe2 100644 --- a/torch/onnx/symbolic_opset11.py +++ b/torch/onnx/symbolic_opset11.py @@ -551,10 +551,7 @@ def arange(g, *args): def _dim_arange(g, like, dim): like_shape = g.op("Shape", like) stop = g.op("Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0) - # Caffe2-specific op - is_caffe2_aten_fallback = (sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK and - torch.onnx._CAFFE2_ATEN_FALLBACK) - if is_caffe2_aten_fallback: + if sym_help.is_caffe2_aten_fallback(): return g.op("_caffe2::Range", stop) return arange(g, stop, 4, None, None, None) @@ -643,7 +640,8 @@ def index(g, self, index): def index_fill(g, self, dim, index, value): dim_value = sym_help._parse_arg(dim, "i") if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: - return g.at("index_fill", self, index, value, dim_i=dim_value, overload_name="int_Scalar") + return g.at("index_fill", self, index, value, overload_name="int_Scalar", dim_i=dim_value) + expanded_index_shape, expanded_index = sym_help._index_fill_reshape_helper(g, self, dim, index) value = sym_help._maybe_get_scalar(value) value = sym_help._if_scalar_type_as(g, value, self) diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index 7ed13d9dac61..2ec1554717d7 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -565,7 +565,7 @@ def transpose(g, self, dim0, dim1): # if we don't have dim information we cannot # output a permute so use ATen instead if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: - return g.at("transpose", self, dim0_i=dim0, dim1_i=dim1, overload_name="int") + return g.at("transpose", self, overload_name="int", dim0_i=dim0, dim1_i=dim1) else: raise RuntimeError("Unsupported: ONNX export of transpose for tensor " "of unknown rank.") @@ -1581,7 +1581,8 @@ def index_put(g, self, indices_list_value, values, accumulate): def index_fill(g, self, dim, index, value): dim_value = sym_help._parse_arg(dim, "i") if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: - return g.at("index_fill", self, index, value, dim_i=dim_value, overload_name="int_Scalar") + return g.at("index_fill", self, index, value, overload_name="int_Scalar", dim_i=dim_value) + expanded_index_shape, expanded_index = sym_help._index_fill_reshape_helper(g, self, dim, index) value = sym_help._maybe_get_scalar(value) value = sym_help._if_scalar_type_as(g, value, self) @@ -1647,9 +1648,7 @@ def type_as(g, self, other): @parse_args("v", "v", "i", "f") def cosine_similarity(g, x1, x2, dim, eps): - # preserve legacy behavior for Caffe2 - if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK and \ - torch.onnx._CAFFE2_ATEN_FALLBACK: + if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: return g.at("cosine_similarity", x1, x2, dim_i=dim, eps_f=eps) cross = sym_help._reducesum_helper(g, mul(g, x1, x2), axes_i=[dim], keepdims_i=0) @@ -2599,10 +2598,7 @@ rnn_relu = _one_hidden_rnn("RNN_RELU") def _dim_arange(g, like, dim): like_shape = g.op("Shape", like) stop = g.op("Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0) - # Caffe2-specific op - is_caffe2_aten_fallback = (sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK and - torch.onnx._CAFFE2_ATEN_FALLBACK) - if is_caffe2_aten_fallback: + if sym_help.is_caffe2_aten_fallback(): return g.op("_caffe2::Range", stop) else: # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 313a3783c2d0..560446405705 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -190,7 +190,7 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa torch._C._jit_pass_peephole(graph, True) torch._C._jit_pass_fuse_addmm(graph) torch._C._jit_pass_lint(graph) - from torch.onnx.symbolic_helper import _onnx_shape_inference, _export_onnx_opset_version + from torch.onnx.symbolic_helper import _onnx_shape_inference, _export_onnx_opset_version, is_caffe2_aten_fallback torch._C._jit_pass_peephole(graph, True) torch._C._jit_pass_lower_all_tuples(graph) @@ -212,13 +212,10 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa torch._C._jit_pass_onnx_remove_print(graph) torch._C._jit_pass_onnx_preprocess_caffe2(graph) - # Caffe2-specific optimization - is_caffe2_aten_fallback = (operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK and - torch.onnx._CAFFE2_ATEN_FALLBACK) torch.onnx.symbolic_helper._quantized_ops.clear() # Unpack quantized weights for conv and linear ops and insert into graph. - torch._C._jit_pass_onnx_unpack_quantized_weights(graph, params_dict, is_caffe2_aten_fallback) - if is_caffe2_aten_fallback: + torch._C._jit_pass_onnx_unpack_quantized_weights(graph, params_dict, is_caffe2_aten_fallback()) + if is_caffe2_aten_fallback(): # Insert permutes before and after each conv op to ensure correct order. torch._C._jit_pass_onnx_quantization_insert_permutes(graph, params_dict) @@ -289,7 +286,7 @@ def warn_on_static_input_change(input_states): def _resolve_args_by_export_type(arg_name, arg_value, operator_export_type): # This helper method resolves the arguments that are ignored when export_type != operator_export_type.ONNX - if operator_export_type is not operator_export_type.ONNX: + if operator_export_type is not operator_export_type.ONNX and torch.onnx._CAFFE2_ATEN_FALLBACK: if arg_value is True: warnings.warn("`{}' can be set to True only when 'operator_export_type' is " "`ONNX`. Since 'operator_export_type' is not set to 'ONNX', " @@ -959,7 +956,8 @@ def _add_attribute(node, key, value, aten): name, kind = m.group(1), m.group(2) if _is_onnx_list(value): kind += "s" - if aten: + from torch.onnx.symbolic_helper import is_caffe2_aten_fallback + if aten and is_caffe2_aten_fallback(): if isinstance(value, torch.Tensor): # Caffe2 proto does not support tensor attribute. if value.numel() > 1: @@ -1119,14 +1117,13 @@ def _run_symbolic_function(g, block, n, inputs, env, operator_export_type=Operat try: import torch from torch.onnx.symbolic_helper import _export_onnx_opset_version as opset_version + from torch.onnx.symbolic_helper import is_caffe2_aten_fallback import torch.onnx.symbolic_registry as sym_registry sym_registry.register_version("", opset_version) # Caffe2-specific: Quantized op symbolics are registered for opset 9 only. - is_caffe2_aten_fallback = (operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK and - torch.onnx._CAFFE2_ATEN_FALLBACK) - if is_caffe2_aten_fallback and opset_version == 9: + if is_caffe2_aten_fallback() and opset_version == 9: import torch.onnx.symbolic_caffe2 torch.onnx.symbolic_caffe2.register_quantized_ops("caffe2", opset_version) @@ -1137,11 +1134,10 @@ def _run_symbolic_function(g, block, n, inputs, env, operator_export_type=Operat else: ns_op_name = n.kind() ns, op_name = ns_op_name.split("::") - domain = ns if ns == "aten": domain = "" - elif ns == "quantized" and is_caffe2_aten_fallback: + elif ns == "quantized" and is_caffe2_aten_fallback(): domain = "caffe2" if sym_registry.is_registered_op(op_name, domain, opset_version): @@ -1165,7 +1161,7 @@ def _run_symbolic_function(g, block, n, inputs, env, operator_export_type=Operat attrs = {k + "_" + n.kindOf(k)[0]: n[k] for k in n.attributeNames()} outputs = n.outputsSize() attrs["outputs"] = outputs - return g.at(op_name, *inputs, aten=True, **attrs) + return g.at(op_name, *inputs, **attrs) else: raise sym_registry.UnsupportedOperatorError(domain, op_name, opset_version) except RuntimeError: @@ -1181,6 +1177,7 @@ def _run_symbolic_function(g, block, n, inputs, env, operator_export_type=Operat # Generate an ONNX ATen op node. def _aten_op(g, operator, *args, overload_name="", **kwargs): + kwargs["aten"] = True return g.op("ATen", *args, operator_s=operator, overload_name_s=overload_name, **kwargs) @@ -1315,7 +1312,7 @@ def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names): torch._C.Graph.op = _graph_op # type: ignore[attr-defined] -torch._C.Graph.at = _aten_op # type: ignore[attr-defined] +torch._C.Graph.at = _aten_op # type: ignore[attr-defined] torch._C.Block.op = _block_op # type: ignore[attr-defined] torch._C.Graph.constant = _graph_constant # type: ignore[attr-defined] torch._C.Node.__getitem__ = _node_getitem # type: ignore[attr-defined, misc, assignment] diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index e93273faa278..44554bab11fa 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -295,6 +295,14 @@ def skipIfNoQNNPACK(fn): fn(*args, **kwargs) return wrapper + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if not torch.onnx._CAFFE2_ATEN_FALLBACK: + raise unittest.SkipTest(reason) + else: + fn(*args, **kwargs) + return wrapper + try: import torchvision # noqa: F401 HAS_TORCHVISION = True diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index bd8732c9369d..cd67f5fe6f1a 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1023,7 +1023,6 @@ def skipIfNoLapack(fn): fn(*args, **kwargs) return wrapper - def skipIfNotRegistered(op_name, message): """Wraps the decorator to hide the import of the `core`. @@ -1045,6 +1044,17 @@ def skipIfNotRegistered(op_name, message): skipper = unittest.skip("Cannot import `caffe2.python.core`") return skipper +def _decide_skip_caffe2(expect_caffe2, reason): + def skip_dec(func): + def wrapper(self): + if torch.onnx._CAFFE2_ATEN_FALLBACK != expect_caffe2: + raise unittest.SkipTest(reason) + return func(self) + return wrapper + return skip_dec + +skipIfCaffe2 = _decide_skip_caffe2(False, "Not compatible with Caffe2") +skipIfNoCaffe2 = _decide_skip_caffe2(True, "Caffe2 is not available") def skipIfNoSciPy(fn): @wraps(fn)