diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index 5e91588e15f1..67e457ed3f1e 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -411,6 +411,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
     ${TORCH_SRC_DIR}/csrc/jit/passes/inline_fork_wait.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/graph_fuser.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/cuda_graph_fuser.cpp
+    ${TORCH_SRC_DIR}/csrc/jit/passes/graph_rewrite_helper.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/guard_elimination.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/inplace_check.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/liveness.cpp
@@ -431,6 +432,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
     ${TORCH_SRC_DIR}/csrc/jit/passes/utils/check_alias_annotation.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/utils/memory_dag.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/quantization.cpp
+    ${TORCH_SRC_DIR}/csrc/jit/passes/xnnpack_rewrite.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/fuse_linear.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/freeze_module.cpp
     ${TORCH_SRC_DIR}/csrc/jit/runtime/print_handler.cpp
diff --git a/test/test_xnnpack_integration.py b/test/test_xnnpack_integration.py
index 2a1f85db57ef..e38029706b69 100644
--- a/test/test_xnnpack_integration.py
+++ b/test/test_xnnpack_integration.py
@@ -5,6 +5,7 @@ import unittest
 import torch
 import torch.backends.xnnpack
 from torch.nn import functional as F
+from torch.testing import FileCheck
 import torch.testing._internal.hypothesis_utils as hu
 from torch.testing._internal.common_utils import TestCase, run_tests
 from hypothesis import given, assume
@@ -360,5 +361,136 @@ class TestXNNPACKSerDes(TestCase):
             torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
 
 
+@unittest.skipUnless(torch.backends.xnnpack.enabled,
+                     "XNNPACK must be enabled for these tests."
+                     " Please build with USE_XNNPACK=1.")
+class TestXNNPACKRewritePass(TestCase):
+    def test_linear(self):
+        def validate_transformed_module(module_class, pattern_count_map, data_shape):
+            scripted_model = torch.jit.script(module_class())
+            input_data = torch.rand(data_shape)
+            ref_result = scripted_model(input_data)
+            torch._C._jit_pass_insert_xnnpack_ops(scripted_model._c)
+
+            buffer = io.BytesIO()
+            torch.jit.save(scripted_model, buffer)
+            buffer.seek(0)
+            deserialized_scripted_model = torch.jit.load(buffer)
+            file_check = FileCheck()
+            for pattern, v in pattern_count_map.items():
+                if v == 0:
+                    file_check.check(pattern)
+                elif v == -1:
+                    file_check.check_not(pattern)
+                else:
+                    file_check.check_count(pattern, v, exactly=True)
+            file_check.run(deserialized_scripted_model.graph)
+            xnnpack_result = deserialized_scripted_model(input_data)
+            torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
+
+        data_shape = [2, 3, 32]
+        weight_output_dim = 24
+        weight_shape = (weight_output_dim, data_shape[-1])
+
+        class Linear(torch.nn.Module):
+            def __init__(self):
+                super(Linear, self).__init__()
+                self.weight = torch.nn.Parameter(torch.Tensor(torch.rand(weight_shape)))
+                self.bias = torch.nn.Parameter(torch.Tensor(torch.rand((weight_output_dim))))
+
+            def forward(self, x):
+                return F.linear(x, self.weight, self.bias)
+
+        class LinearNoBias(torch.nn.Module):
+            def __init__(self):
+                super(LinearNoBias, self).__init__()
+                self.weight = torch.nn.Parameter(torch.Tensor(torch.rand(weight_shape)))
+
+            def forward(self, x):
+                return F.linear(x, self.weight, None)
+
+        # Linear patterns, with and without bias.
+        pattern_count_map = {"Tensor = prim::CallFunction": -1,
+                             "_xnnpack::linear_prepack": 1,
+                             "_xnnpack::linear_packed": 1}
+        validate_transformed_module(Linear, pattern_count_map, data_shape)
+        validate_transformed_module(LinearNoBias, pattern_count_map, data_shape)
+
+        # Conv params
+        batch_size = 2
+        input_channels_per_group = 6
+        height = 16
+        width = 16
+        output_channels_per_group = 6
+        groups = 4
+        kernel_h = kernel_w = 3
+        stride_h = stride_w = 1
+        pad_h = pad_w = 1
+        dilation = 1
+        input_channels = input_channels_per_group * groups
+        output_channels = output_channels_per_group * groups
+        kernels = (kernel_h, kernel_w)
+        strides = (stride_h, stride_w)
+        paddings = (pad_h, pad_w)
+        dilations = (dilation, dilation)
+        conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
+        conv_bias_shape = (output_channels)
+
+        class Conv2D(torch.nn.Module):
+            def __init__(self):
+                super(Conv2D, self).__init__()
+                self.weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_weight_shape)))
+                self.bias = torch.nn.Parameter(torch.Tensor(torch.rand(conv_bias_shape)))
+                self.strides = strides
+                self.paddings = paddings
+                self.dilations = dilations
+                self.groups = groups
+
+            def forward(self, x):
+                return F.conv2d(x, self.weight, self.bias,
+                                self.strides, self.paddings, self.dilations, self.groups)
+
+        data_shape = (batch_size, input_channels, height, width)
+        pattern_count_map = {"Tensor = aten::conv2d": -1,
+                             "_xnnpack::conv2d_prepack": 1,
+                             "_xnnpack::conv2d_packed": 1}
+        validate_transformed_module(Conv2D, pattern_count_map, data_shape)
+
+        input_data = torch.rand((batch_size, input_channels, height, width))
+        conv_weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w))
+        conv_bias = torch.rand((output_channels))
+        result = F.conv2d(input_data, conv_weight, conv_bias,
+                          strides, paddings, dilations, groups)
+        linear_input_shape = result.shape[1]
+        linear_weight_shape = (weight_output_dim, linear_input_shape)
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                self.conv_weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_weight_shape)))
+                self.conv_bias = torch.nn.Parameter(torch.Tensor(torch.rand((conv_bias_shape))))
+                self.linear_weight = torch.nn.Parameter(torch.Tensor(torch.rand(linear_weight_shape)))
+                self.linear_bias = torch.nn.Parameter(torch.Tensor(torch.rand((weight_output_dim))))
+                self.strides = strides
+                self.paddings = paddings
+                self.dilations = dilations
+                self.groups = groups
+
+            def forward(self, x):
+                o = F.conv2d(x, self.conv_weight, self.conv_bias,
+                             self.strides, self.paddings, self.dilations, self.groups)
+                o = o.permute([0, 2, 3, 1])
+                o = F.linear(o, self.linear_weight, self.linear_bias)
+                return F.relu(o)
+
+        pattern_count_map = {"Tensor = aten::conv2d": -1,
+                             "_xnnpack::conv2d_prepack": 1,
+                             "_xnnpack::conv2d_packed": 1,
+                             "Tensor = prim::CallFunction": -1,
+                             "_xnnpack::linear_prepack": 1,
+                             "_xnnpack::linear_packed": 1}
+        validate_transformed_module(M, pattern_count_map, data_shape)
+
+
 if __name__ == "__main__":
     run_tests()
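For orientation, the test above boils down to a script -> rewrite -> serialize round trip. A minimal standalone sketch of that flow, assuming an XNNPACK-enabled build (the TinyLinear module and its shapes are illustrative, not part of this patch):

    import io

    import torch
    import torch.nn.functional as F

    class TinyLinear(torch.nn.Module):
        # Illustrative module, not from this patch.
        def __init__(self):
            super(TinyLinear, self).__init__()
            self.weight = torch.nn.Parameter(torch.rand(24, 32))
            self.bias = torch.nn.Parameter(torch.rand(24))

        def forward(self, x):
            return F.linear(x, self.weight, self.bias)

    scripted = torch.jit.script(TinyLinear())
    torch._C._jit_pass_insert_xnnpack_ops(scripted._c)  # rewrite linear into _xnnpack ops

    buffer = io.BytesIO()
    torch.jit.save(scripted, buffer)  # the packed params must survive a save/load round trip
    buffer.seek(0)
    restored = torch.jit.load(buffer)
    # Expect _xnnpack::linear_prepack / _xnnpack::linear_packed in the graph dump.
    print(restored.graph)
    print(restored(torch.rand(2, 3, 32)).shape)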
+ pattern_count_map = {"Tensor = prim::CallFunction": -1, + "_xnnpack::linear_prepack": 1, + "_xnnpack::linear_packed": 1} + validate_transformed_module(Linear, pattern_count_map, data_shape) + validate_transformed_module(LinearNoBias, pattern_count_map, data_shape) + + # Conv params + batch_size = 2 + input_channels_per_group = 6 + height = 16 + width = 16 + output_channels_per_group = 6 + groups = 4 + kernel_h = kernel_w = 3 + stride_h = stride_w = 1 + pad_h = pad_w = 1 + dilation = 1 + input_channels = input_channels_per_group * groups + output_channels = output_channels_per_group * groups + kernels = (kernel_h, kernel_w) + strides = (stride_h, stride_w) + paddings = (pad_h, pad_w) + dilations = (dilation, dilation) + conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w) + conv_bias_shape = (output_channels) + + class Conv2D(torch.nn.Module): + def __init__(self): + super(Conv2D, self).__init__() + self.weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_weight_shape))) + self.bias = torch.nn.Parameter(torch.Tensor(torch.rand(conv_bias_shape))) + self.strides = strides + self.paddings = paddings + self.dilations = dilations + self.groups = groups + + def forward(self, x): + return F.conv2d(x, self.weight, self.bias, + self.strides, self.paddings, self.dilations, self.groups) + + data_shape = (batch_size, input_channels, height, width) + pattern_count_map = {"Tensor = aten::conv2d": -1, + "_xnnpack::conv2d_prepack": 1, + "_xnnpack::conv2d_packed": 1} + validate_transformed_module(Conv2D, pattern_count_map, data_shape) + + input_data = torch.rand((batch_size, input_channels, height, width)) + conv_weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w)) + conv_bias = torch.rand((output_channels)) + result = F.conv2d(input_data, conv_weight, conv_bias, + strides, paddings, dilations, groups) + linear_input_shape = result.shape[1] + linear_weight_shape = (weight_output_dim, linear_input_shape) + + class M(torch.nn.Module): + def __init__(self): + super(M, self).__init__() + self.conv_weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_weight_shape))) + self.conv_bias = torch.nn.Parameter(torch.Tensor(torch.rand((conv_bias_shape)))) + self.linear_weight = torch.nn.Parameter(torch.Tensor(torch.rand(linear_weight_shape))) + self.linear_bias = torch.nn.Parameter(torch.Tensor(torch.rand((weight_output_dim)))) + self.strides = strides + self.paddings = paddings + self.dilations = dilations + self.groups = groups + + def forward(self, x): + o = F.conv2d(x, self.conv_weight, self.conv_bias, + self.strides, self.paddings, self.dilations, self.groups) + o = o.permute([0, 2, 3, 1]) + o = F.linear(o, self.linear_weight, self.linear_bias) + return F.relu(o) + + pattern_count_map = {"Tensor = aten::conv2d": -1, + "_xnnpack::conv2d_prepack": 1, + "_xnnpack::conv2d_packed": 1, + "Tensor = prim::CallFunction": -1, + "_xnnpack::linear_prepack": 1, + "_xnnpack::linear_packed": 1} + validate_transformed_module(M, pattern_count_map, data_shape) + + if __name__ == "__main__": run_tests() diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index 0ca0464eccf0..af287cf1be79 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -120,6 +120,7 @@ libtorch_sources = [ "torch/csrc/jit/passes/erase_number_types.cpp", "torch/csrc/jit/passes/fixup_trace_scope_blocks.cpp", "torch/csrc/jit/passes/graph_fuser.cpp", + "torch/csrc/jit/passes/graph_rewrite_helper.cpp", "torch/csrc/jit/passes/cuda_graph_fuser.cpp", 
"torch/csrc/jit/passes/guard_elimination.cpp", "torch/csrc/jit/passes/inline_autodiff_subgraphs.cpp", @@ -136,6 +137,7 @@ libtorch_sources = [ "torch/csrc/jit/passes/peephole.cpp", "torch/csrc/jit/serialization/python_print.cpp", "torch/csrc/jit/passes/quantization.cpp", + "torch/csrc/jit/passes/xnnpack_rewrite.cpp", "torch/csrc/jit/passes/fuse_linear.cpp", "torch/csrc/jit/passes/freeze_module.cpp", "torch/csrc/jit/passes/remove_expands.cpp", diff --git a/torch/csrc/jit/passes/graph_rewrite_helper.cpp b/torch/csrc/jit/passes/graph_rewrite_helper.cpp new file mode 100644 index 000000000000..3eb7bc8f53a3 --- /dev/null +++ b/torch/csrc/jit/passes/graph_rewrite_helper.cpp @@ -0,0 +1,82 @@ +#include +#include +#include + +namespace torch { +namespace jit { +namespace graph_rewrite_helper { + +std::string getFuncName(Value* func_value) { + auto func_node = func_value->node(); + auto func = func_node->output()->type()->expect()->function(); + const auto& qname = func->qualname(); + const auto& name = qname.qualifiedName(); + auto rdot_idx = name.rfind('.'); + if (rdot_idx != std::string::npos) { + return name.substr(rdot_idx + 1, name.length()); + } else { + return name; + } +} + +Value* getValue( + const std::string& name, + const std::unordered_map& match_vmap, + const std::unordered_map& vmap) { + return match_vmap.at(vmap.at(name)); +} + +c10::optional getIValue( + const std::string& name, + const std::unordered_map& match_vmap, + const std::unordered_map& vmap) { + return toIValue(getValue(name, match_vmap, vmap)); +} + +void replaceConvolutionWithConv2d(std::shared_ptr& graph) { + std::string convolution = R"( + graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], + %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, + %deterministic:bool, %cudnn_enabled:bool): + %r = aten::_convolution(%a, %w, %b, %stride, %padding, %dilation, + %transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled) + return (%r) )"; + + std::string conv2d = R"( + graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], + %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, + %deterministic:bool, %cudnn_enabled:bool): + %r = aten::conv2d(%a, %w, %b, %stride, %padding, %dilation, %groups) + return (%r) )"; + + // Filter the unsupported case + auto filter = [](const Match& match, + const std::unordered_map& vmap) { + const auto& match_vmap = match.values_map; + auto transposed_value = + getIValue("transposed", match_vmap, vmap).value().toBool(); + auto benchmark_value = + getIValue("benchmark", match_vmap, vmap).value().toBool(); + auto deterministic_value = + getIValue("deterministic", match_vmap, vmap).value().toBool(); + auto cudnn_enabled_value = + getIValue("cudnn_enabled", match_vmap, vmap).value().toBool(); + auto output_padding_value = + getIValue("output_padding", match_vmap, vmap).value().toIntList(); + + if (!transposed_value && !benchmark_value && !deterministic_value && + cudnn_enabled_value && (output_padding_value[0] == 0) && + (output_padding_value[1] == 0)) { + return true; + } + return false; + }; + + SubgraphRewriter rewriter; + rewriter.RegisterRewritePattern(convolution, conv2d); + rewriter.runOnGraph(graph, filter); +} + +} // namespace graph_rewrite_helper +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/graph_rewrite_helper.h b/torch/csrc/jit/passes/graph_rewrite_helper.h new file mode 100644 index 000000000000..1da61d23cac3 --- /dev/null +++ b/torch/csrc/jit/passes/graph_rewrite_helper.h 
diff --git a/torch/csrc/jit/passes/quantization.cpp b/torch/csrc/jit/passes/quantization.cpp
index 92fdb0d2ecbf..9f477ee05ad1 100644
--- a/torch/csrc/jit/passes/quantization.cpp
+++ b/torch/csrc/jit/passes/quantization.cpp
@@ -2,6 +2,7 @@
 #include
 #include
 #include
+#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
 #include
 #include
 #include
@@ -27,6 +28,11 @@ namespace {
 using OptionalModuleVector = std::vector<c10::optional<Module>>;
 using ModuleMethodVector = std::vector<std::pair<Module, std::string>>;
 using NameModuleVector = std::vector<std::pair<std::string, Module>>;
+using graph_rewrite_helper::getValue;
+using graph_rewrite_helper::getIValue;
+using graph_rewrite_helper::getFuncName;
+using graph_rewrite_helper::replaceConvolutionWithConv2d;
+
 // Map of quantization parameter name and value
 // for example _scale, _zero_point,
 // _scalar_type and _axis(for per channel quantization)
@@ -58,20 +64,6 @@ struct PatternsAndModules {
   Module packed_params_module;
 };
 
-static Value* getValue(
-    const std::string& name,
-    const std::unordered_map<const Value*, Value*>& match_vmap,
-    const std::unordered_map<std::string, Value*>& vmap) {
-  return match_vmap.at(vmap.at(name));
-}
-
-static c10::optional<IValue> getIValue(
-    const std::string& name,
-    const std::unordered_map<const Value*, Value*>& match_vmap,
-    const std::unordered_map<std::string, Value*>& vmap) {
-  return toIValue(getValue(name, match_vmap, vmap));
-}
-
 void fillQConfigMap(
     const Module& module,
     const QConfigDict& qconfig_dict,
@@ -97,19 +89,6 @@ void fillQConfigMap(
   }
 }
 
-std::string getFuncName(Value* func_value) {
-  auto func_node = func_value->node();
-  auto func = func_node->output()->type()->expect<FunctionType>()->function();
-  const auto& qname = func->qualname();
-  const auto& name = qname.qualifiedName();
-  auto rdot_idx = name.rfind('.');
-  if (rdot_idx != std::string::npos) {
-    return name.substr(rdot_idx + 1, name.length());
-  } else {
-    return name;
-  }
-}
-
 bool isFunctionNode(Node* n,
     const std::vector<std::string>& call_funcs,
     const std::vector<std::string>& aten_funcs) {
@@ -667,45 +646,6 @@ Module getObserverModuleFor(Value* v, const QConfig& qconfig) {
   return isWeightOfConvOrLinear(v) ? std::get<1>(qconfig) : std::get<0>(qconfig);
 }
 
-void replaceConvolutionWithConv2d(std::shared_ptr<Graph>& graph) {
-  std::string convolution = R"(
-graph(%a, %w, %b, %stride, %padding, %dilation, %transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled):
-  %r = aten::_convolution(%a, %w, %b, %stride, %padding, %dilation, %transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled)
-  return (%r) )";
-
-  std::string conv2d = R"(
-graph(%a, %w, %b, %stride, %padding, %dilation, %transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled):
-  %r = aten::conv2d(%a, %w, %b, %stride, %padding, %dilation, %groups)
-  return (%r) )";
-
-  // Filter the unsupported case
-  auto filter = [](const Match& match,
-                   const std::unordered_map<std::string, Value*>& vmap) {
-    const auto& match_vmap = match.values_map;
-    auto transposed_value =
-        getIValue("transposed", match_vmap, vmap).value().toBool();
-    auto benchmark_value =
-        getIValue("benchmark", match_vmap, vmap).value().toBool();
-    auto deterministic_value =
-        getIValue("deterministic", match_vmap, vmap).value().toBool();
-    auto cudnn_enabled_value =
-        getIValue("cudnn_enabled", match_vmap, vmap).value().toBool();
-    auto output_padding_value =
-        getIValue("output_padding", match_vmap, vmap).value().toIntList();
-
-    if (!transposed_value && !benchmark_value && !deterministic_value &&
-        cudnn_enabled_value && (output_padding_value[0] == 0) &&
-        (output_padding_value[1] == 0)) {
-      return true;
-    }
-    return false;
-  };
-
-  SubgraphRewriter rewriter;
-  rewriter.RegisterRewritePattern(convolution, conv2d);
-  rewriter.runOnGraph(graph, filter);
-}
-
 ModuleMethodVector InsertObserversHelper::getInvokedMethods(
     Module& module,
     const std::string& method_name) {
diff --git a/torch/csrc/jit/passes/xnnpack_rewrite.cpp b/torch/csrc/jit/passes/xnnpack_rewrite.cpp
new file mode 100644
index 000000000000..65905555e89f
--- /dev/null
+++ b/torch/csrc/jit/passes/xnnpack_rewrite.cpp
@@ -0,0 +1,110 @@
+#include <torch/csrc/jit/ir/ir.h>
+#include <torch/csrc/jit/ir/subgraph_matcher.h>
+#include <torch/csrc/jit/passes/constant_pooling.h>
+#include <torch/csrc/jit/passes/constant_propagation.h>
+#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
+#include <torch/csrc/jit/passes/subgraph_rewrite.h>
+#include <torch/csrc/jit/passes/xnnpack_rewrite.h>
+
+namespace torch {
+namespace jit {
+
+#ifdef USE_XNNPACK
+
+namespace {
+
+void insertXNNPACKLinearOp(std::shared_ptr<Graph>& graph) {
+  std::string linear_before_inline = R"(
+    graph(%linear, %input, %weight, %bias):
+        %r = prim::CallFunction(%linear, %input, %weight, %bias)
+        return (%r))";
+  std::string xnnpack_pattern_before_inline = R"(
+    graph(%linear, %input, %weight, %bias):
+        %packed_weight_bias = _xnnpack::linear_prepack(%weight, %bias)
+        %res = _xnnpack::linear_packed(%input, %packed_weight_bias)
+        return (%res))";
+  std::string linear_pattern = R"(
+    graph(%input, %weight, %bias):
+        %r = aten::linear(%input, %weight, %bias)
+        return (%r))";
+  std::string xnnpack_pattern = R"(
+    graph(%input, %weight, %bias):
+        %packed_weight_bias = _xnnpack::linear_prepack(%weight, %bias)
+        %res = _xnnpack::linear_packed(%input, %packed_weight_bias)
+        return (%res))";
+
+  // Only rewrite prim::CallFunction nodes that actually call linear.
+  auto filter = [](const Match& match,
+                   const std::unordered_map<std::string, Value*>& vmap) {
+    const auto& match_vmap = match.values_map;
+    auto linear_value = match_vmap.at(vmap.at("linear"));
+    auto func_name = graph_rewrite_helper::getFuncName(linear_value);
+    if (func_name == "linear") {
+      return true;
+    }
+    return false;
+  };
+
+  SubgraphRewriter linear_call_fn_rewriter;
+  linear_call_fn_rewriter.RegisterRewritePattern(
+      linear_before_inline, xnnpack_pattern_before_inline);
+  linear_call_fn_rewriter.runOnGraph(graph, filter);
+
+  SubgraphRewriter linear_rewriter;
+  linear_rewriter.RegisterRewritePattern(linear_pattern, xnnpack_pattern);
+  linear_rewriter.runOnGraph(graph);
+}
+
+void insertXNNPACKConv2dOp(std::shared_ptr<Graph>& graph) {
+  // Replace _convolution with conv2d
+  graph_rewrite_helper::replaceConvolutionWithConv2d(graph);
+
+  std::string conv_2d_pattern = R"(
+    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
+        %r = aten::conv2d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
+        return (%r) )";
+
+  std::string xnnpack_conv2d_pattern = R"(
+    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
+        %packed_weight_bias = _xnnpack::conv2d_prepack(%weight, %bias, %stride, %padding, %dilation, %groups)
+        %r = _xnnpack::conv2d_packed(%input, %packed_weight_bias)
+        return (%r) )";
+
+  SubgraphRewriter rewriter;
+  rewriter.RegisterRewritePattern(conv_2d_pattern, xnnpack_conv2d_pattern);
+  rewriter.runOnGraph(graph);
+}
+
+} // namespace
+
+void insertXNNPACKOps(std::shared_ptr<Graph>& graph) {
+  // Pool and propagate constants first so that weights, biases and conv
+  // arguments show up as constants the rewrite filters can inspect.
+  ConstantPooling(graph);
+  ConstantPropagation(graph);
+  insertXNNPACKLinearOp(graph);
+  insertXNNPACKConv2dOp(graph);
+}
+
+void insertXNNPACKOps(script::Module& module) {
+  for (auto& method : module.get_methods()) {
+    auto graph = method.graph();
+    insertXNNPACKOps(graph);
+  }
+  for (script::Module m : module.children()) {
+    insertXNNPACKOps(m);
+  }
+}
+
+#else
+
+void insertXNNPACKOps(std::shared_ptr<Graph>& graph) {
+  TORCH_INTERNAL_ASSERT(
+      false, "XNNPACK is not enabled. Please build with USE_XNNPACK=1");
+}
+
+void insertXNNPACKOps(script::Module& module) {
+  TORCH_INTERNAL_ASSERT(
+      false, "XNNPACK is not enabled. Please build with USE_XNNPACK=1");
+}
+
+#endif
+} // namespace jit
+} // namespace torch
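Taken together, insertXNNPACKOps normalizes constants and then rewrites linear and conv2d calls into prepack/packed pairs. A minimal sketch of the conv2d path from Python, assuming an XNNPACK-enabled build (the TinyConv module is illustrative, not part of this patch):

    import torch
    import torch.nn.functional as F

    class TinyConv(torch.nn.Module):
        # Illustrative module, not from this patch.
        def __init__(self):
            super(TinyConv, self).__init__()
            self.weight = torch.nn.Parameter(torch.rand(16, 3, 3, 3))
            self.bias = torch.nn.Parameter(torch.rand(16))

        def forward(self, x):
            return F.conv2d(x, self.weight, self.bias, (1, 1), (1, 1), (1, 1), 1)

    scripted = torch.jit.script(TinyConv())
    print(scripted.graph)  # shows aten::conv2d before the rewrite
    torch._C._jit_pass_insert_xnnpack_ops(scripted._c)
    # Expect _xnnpack::conv2d_prepack / _xnnpack::conv2d_packed afterwards.
    print(scripted.graph)
    print(scripted(torch.rand(1, 3, 8, 8)).shape)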
diff --git a/torch/csrc/jit/passes/xnnpack_rewrite.h b/torch/csrc/jit/passes/xnnpack_rewrite.h
new file mode 100644
index 000000000000..59840f118bad
--- /dev/null
+++ b/torch/csrc/jit/passes/xnnpack_rewrite.h
@@ -0,0 +1,10 @@
+#pragma once
+
+#include <torch/csrc/jit/api/module.h>
+
+namespace torch {
+namespace jit {
+TORCH_API void insertXNNPACKOps(std::shared_ptr<Graph>& graph);
+TORCH_API void insertXNNPACKOps(script::Module& module);
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index edaaca51af1f..c2b9bd88e3ac 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -46,6 +46,7 @@
 #include
 #include
 #include
+#include <torch/csrc/jit/passes/xnnpack_rewrite.h>
 #include
 #include
 #include
@@ -409,6 +410,16 @@ void initJITBindings(PyObject* module) {
           [](Graph& g, std::vector<at::Tensor> inps) {
             return debugGetFusedKernelCode(g, inps);
           })
+      .def(
+          "_jit_pass_insert_xnnpack_ops",
+          [](std::shared_ptr<Graph>& graph) {
+            return insertXNNPACKOps(graph);
+          })
+      .def(
+          "_jit_pass_insert_xnnpack_ops",
+          [](script::Module& module) {
+            return insertXNNPACKOps(module);
+          })
       .def(
           "_jit_pass_onnx_unpack_quantized_weights",
           [](std::shared_ptr<Graph>& graph,
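With the bindings above, the pass is reachable from Python with either a graph or a module argument. A minimal sketch of the two call forms, assuming an XNNPACK-enabled build and that both overloads resolve as registered (the modules are illustrative):

    import torch

    # Module form: insertXNNPACKOps walks every method of the module and its children.
    scripted = torch.jit.script(torch.nn.Linear(32, 24))
    torch._C._jit_pass_insert_xnnpack_ops(scripted._c)

    # Graph form: rewrites a single graph in place.
    scripted2 = torch.jit.script(torch.nn.Linear(32, 24))
    torch._C._jit_pass_insert_xnnpack_ops(scripted2.forward.graph)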