Prepare for an update to the XNNPACK submodule (#72642)

Summary:
- Target Sha1: ae108ef49aa5623b896fc93d4298c49d1750d9ba
- Make USE_XNNPACK a dependent option on cmake minimum version 3.12
- Print USE_XNNPACK under cmake options summary, and print the
  availability from collet_env.py
- Skip XNNPACK based tests when XNNPACK is not available
    - Add SkipIfNoXNNPACK wrapper to skip tests
- Update cmake version for xenial-py3.7-gcc5.4 image to 3.12.4
    - This is required for the backwards compatibility test.
      The PyTorch op schema is XNNPACK dependent. See,
      aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp for
      example. The nightly version is assumed to have USE_XNNPACK=ON,
      so with this change we ensure that the test build can also
      have XNNPACK.
- HACK: skipping test_xnnpack_integration tests on ROCM

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72642

Reviewed By: kimishpatel

Differential Revision: D34456794

Pulled By: digantdesai

fbshipit-source-id: 85dbfe0211de7846d8a84321b14fdb061cd6c037
(cherry picked from commit 6cf48e7b64d6979962d701b5d493998262cc8bfa)
This commit is contained in:
Digant Desai
2022-02-24 15:24:21 -08:00
committed by PyTorch MergeBot
parent abb55c53b3
commit b2054d3025
9 changed files with 42 additions and 16 deletions

View File

@ -90,7 +90,7 @@ case "$image" in
;; ;;
pytorch-linux-xenial-py3.7-gcc5.4) pytorch-linux-xenial-py3.7-gcc5.4)
ANACONDA_PYTHON_VERSION=3.7 ANACONDA_PYTHON_VERSION=3.7
CMAKE_VERSION=3.10.3 CMAKE_VERSION=3.12.4 # To make sure XNNPACK is enabled for the BACKWARDS_COMPAT_TEST used with this image
GCC_VERSION=5 GCC_VERSION=5
PROTOBUF=yes PROTOBUF=yes
DB=yes DB=yes

View File

@ -285,7 +285,14 @@ option(USE_LITE_INTERPRETER_PROFILER "Enable " ON)
option(USE_VULKAN_FP16_INFERENCE "Vulkan - Use fp16 inference" OFF) option(USE_VULKAN_FP16_INFERENCE "Vulkan - Use fp16 inference" OFF)
option(USE_VULKAN_RELAXED_PRECISION "Vulkan - Use relaxed precision math in the kernels (mediump)" OFF) option(USE_VULKAN_RELAXED_PRECISION "Vulkan - Use relaxed precision math in the kernels (mediump)" OFF)
option(USE_VULKAN_SHADERC_RUNTIME "Vulkan - Use runtime shader compilation as opposed to build-time (needs libshaderc)" OFF) option(USE_VULKAN_SHADERC_RUNTIME "Vulkan - Use runtime shader compilation as opposed to build-time (needs libshaderc)" OFF)
option(USE_XNNPACK "Use XNNPACK" ON) # option USE_XNNPACK: try to enable xnnpack by default.
set(XNNPACK_MIN_CMAKE_VER 3.12)
cmake_dependent_option(
USE_XNNPACK "Use XNNPACK. Requires cmake >= ${XNNPACK_MIN_CMAKE_VER}." ON
"CMAKE_VERSION VERSION_GREATER_EQUAL ${XNNPACK_MIN_CMAKE_VER}" OFF)
if(NOT USE_XNNPACK AND CMAKE_VERSION VERSION_LESS ${XNNPACK_MIN_CMAKE_VER})
message(WARNING "USE_XNNPACK is set to OFF. XNNPACK requires CMake version ${XNNPACK_MIN_CMAKE_VER} or greater.")
endif()
option(USE_ZMQ "Use ZMQ" OFF) option(USE_ZMQ "Use ZMQ" OFF)
option(USE_ZSTD "Use ZSTD" OFF) option(USE_ZSTD "Use ZSTD" OFF)
# Ensure that an MKLDNN build is the default for x86 CPUs # Ensure that an MKLDNN build is the default for x86 CPUs

View File

@ -171,6 +171,7 @@ function(caffe2_print_configuration_summary)
message(STATUS " USE_PROF : ${USE_PROF}") message(STATUS " USE_PROF : ${USE_PROF}")
message(STATUS " USE_QNNPACK : ${USE_QNNPACK}") message(STATUS " USE_QNNPACK : ${USE_QNNPACK}")
message(STATUS " USE_PYTORCH_QNNPACK : ${USE_PYTORCH_QNNPACK}") message(STATUS " USE_PYTORCH_QNNPACK : ${USE_PYTORCH_QNNPACK}")
message(STATUS " USE_XNNPACK : ${USE_XNNPACK}")
message(STATUS " USE_REDIS : ${USE_REDIS}") message(STATUS " USE_REDIS : ${USE_REDIS}")
message(STATUS " USE_ROCKSDB : ${USE_ROCKSDB}") message(STATUS " USE_ROCKSDB : ${USE_ROCKSDB}")
message(STATUS " USE_ZMQ : ${USE_ZMQ}") message(STATUS " USE_ZMQ : ${USE_ZMQ}")

View File

@ -2,9 +2,9 @@
import torch import torch
import torch._C import torch._C
import torch.backends.xnnpack
import torch.nn.functional as F import torch.nn.functional as F
from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_utils import skipIfNoXNNPACK
class TestOptimizeForMobilePreserveDebugInfo(JitTestCase): class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
def check_replacement( def check_replacement(
@ -36,6 +36,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
original_source_ranges[replacements[node.kind()]], original_source_ranges[replacements[node.kind()]],
) )
@skipIfNoXNNPACK
def test_replace_conv1d_with_conv2d(self): def test_replace_conv1d_with_conv2d(self):
class TestConv1d(torch.nn.Module): class TestConv1d(torch.nn.Module):
def __init__(self, weight, bias): def __init__(self, weight, bias):
@ -63,6 +64,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
jit_pass=torch._C._jit_pass_transform_conv1d_to_conv2d, jit_pass=torch._C._jit_pass_transform_conv1d_to_conv2d,
) )
@skipIfNoXNNPACK
def test_insert_pre_packed_linear_before_inline_and_conv_2d_op(self): def test_insert_pre_packed_linear_before_inline_and_conv_2d_op(self):
class TestPrepackedLinearBeforeInlineAndConv2dOp(torch.nn.Module): class TestPrepackedLinearBeforeInlineAndConv2dOp(torch.nn.Module):
def __init__( def __init__(
@ -139,6 +141,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
jit_pass=torch._C._jit_pass_insert_prepacked_ops, jit_pass=torch._C._jit_pass_insert_prepacked_ops,
) )
@skipIfNoXNNPACK
def test_insert_pre_packed_linear_op(self): def test_insert_pre_packed_linear_op(self):
self.check_replacement( self.check_replacement(
model=torch.jit.trace(torch.nn.Linear(5, 4), torch.rand(3, 2, 5)), model=torch.jit.trace(torch.nn.Linear(5, 4), torch.rand(3, 2, 5)),
@ -230,6 +233,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
jit_pass=torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv, jit_pass=torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv,
) )
@skipIfNoXNNPACK
def test_fuse_activation_with_pack_ops_linear_conv2d_1(self): def test_fuse_activation_with_pack_ops_linear_conv2d_1(self):
self.run_test_fuse_activation_with_pack_ops_linear_conv2d( self.run_test_fuse_activation_with_pack_ops_linear_conv2d(
linear_activation=F.hardtanh, linear_activation=F.hardtanh,
@ -238,6 +242,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
conv2d_activation_kind="aten::hardtanh_" conv2d_activation_kind="aten::hardtanh_"
) )
@skipIfNoXNNPACK
def test_fuse_activation_with_pack_ops_linear_conv2d_2(self): def test_fuse_activation_with_pack_ops_linear_conv2d_2(self):
self.run_test_fuse_activation_with_pack_ops_linear_conv2d( self.run_test_fuse_activation_with_pack_ops_linear_conv2d(
linear_activation=F.hardtanh_, linear_activation=F.hardtanh_,
@ -246,6 +251,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
conv2d_activation_kind="aten::hardtanh" conv2d_activation_kind="aten::hardtanh"
) )
@skipIfNoXNNPACK
def test_fuse_activation_with_pack_ops_linear_conv2d_3(self): def test_fuse_activation_with_pack_ops_linear_conv2d_3(self):
self.run_test_fuse_activation_with_pack_ops_linear_conv2d( self.run_test_fuse_activation_with_pack_ops_linear_conv2d(
linear_activation=F.relu, linear_activation=F.relu,
@ -254,6 +260,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
conv2d_activation_kind="aten::relu_" conv2d_activation_kind="aten::relu_"
) )
@skipIfNoXNNPACK
def test_fuse_activation_with_pack_ops_linear_conv2d_4(self): def test_fuse_activation_with_pack_ops_linear_conv2d_4(self):
self.run_test_fuse_activation_with_pack_ops_linear_conv2d( self.run_test_fuse_activation_with_pack_ops_linear_conv2d(
linear_activation=F.relu_, linear_activation=F.relu_,

View File

@ -3,9 +3,8 @@
import unittest import unittest
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.backends.xnnpack
import torch.utils.bundled_inputs import torch.utils.bundled_inputs
from torch.testing._internal.common_utils import TestCase, run_tests from torch.testing._internal.common_utils import TestCase, run_tests, skipIfNoXNNPACK
from torch.testing._internal.jit_utils import get_forward, get_forward_graph from torch.testing._internal.jit_utils import get_forward, get_forward_graph
from torch.utils.mobile_optimizer import (LintCode, from torch.utils.mobile_optimizer import (LintCode,
generate_mobile_module_lints, generate_mobile_module_lints,
@ -24,9 +23,7 @@ FileCheck = torch._C.FileCheck
class TestOptimizer(TestCase): class TestOptimizer(TestCase):
@unittest.skipUnless(torch.backends.xnnpack.enabled, @skipIfNoXNNPACK
" XNNPACK must be enabled for these tests."
" Please build with USE_XNNPACK=1.")
def test_optimize_for_mobile(self): def test_optimize_for_mobile(self):
batch_size = 2 batch_size = 2
input_channels_per_group = 6 input_channels_per_group = 6
@ -265,9 +262,7 @@ class TestOptimizer(TestCase):
rtol=1e-2, rtol=1e-2,
atol=1e-3) atol=1e-3)
@unittest.skipUnless(torch.backends.xnnpack.enabled, @skipIfNoXNNPACK
" XNNPACK must be enabled for these tests."
" Please build with USE_XNNPACK=1.")
def test_quantized_conv_no_asan_failures(self): def test_quantized_conv_no_asan_failures(self):
# There were ASAN failures when fold_conv_bn was run on # There were ASAN failures when fold_conv_bn was run on
# already quantized conv modules. Verifying that this does # already quantized conv modules. Verifying that this does
@ -361,6 +356,7 @@ class TestOptimizer(TestCase):
bi_module_lint_list = generate_mobile_module_lints(bi_module) bi_module_lint_list = generate_mobile_module_lints(bi_module)
self.assertEqual(len(bi_module_lint_list), 0) self.assertEqual(len(bi_module_lint_list), 0)
@skipIfNoXNNPACK
def test_preserve_bundled_inputs_methods(self): def test_preserve_bundled_inputs_methods(self):
class MyBundledInputModule(torch.nn.Module): class MyBundledInputModule(torch.nn.Module):
def __init__(self): def __init__(self):
@ -415,9 +411,7 @@ class TestOptimizer(TestCase):
incomplete_bi_module_optim = optimize_for_mobile(incomplete_bi_module, preserved_methods=['get_all_bundled_inputs']) incomplete_bi_module_optim = optimize_for_mobile(incomplete_bi_module, preserved_methods=['get_all_bundled_inputs'])
self.assertTrue(hasattr(incomplete_bi_module_optim, 'get_all_bundled_inputs')) self.assertTrue(hasattr(incomplete_bi_module_optim, 'get_all_bundled_inputs'))
@unittest.skipUnless(torch.backends.xnnpack.enabled, @skipIfNoXNNPACK
" XNNPACK must be enabled for these tests."
" Please build with USE_XNNPACK=1.")
def test_hoist_conv_packed_params(self): def test_hoist_conv_packed_params(self):
if 'qnnpack' not in torch.backends.quantized.supported_engines: if 'qnnpack' not in torch.backends.quantized.supported_engines:
@ -511,6 +505,7 @@ class TestOptimizer(TestCase):
m_optim_res = m_optim(data) m_optim_res = m_optim(data)
torch.testing.assert_close(m_res, m_optim_res, rtol=1e-2, atol=1e-3) torch.testing.assert_close(m_res, m_optim_res, rtol=1e-2, atol=1e-3)
@skipIfNoXNNPACK
@unittest.skipUnless(HAS_TORCHVISION, "Needs torchvision") @unittest.skipUnless(HAS_TORCHVISION, "Needs torchvision")
def test_mobilenet_optimize_for_mobile(self): def test_mobilenet_optimize_for_mobile(self):
m = torchvision.models.mobilenet_v3_small() m = torchvision.models.mobilenet_v3_small()

View File

@ -10,9 +10,10 @@ import urllib
import unittest import unittest
import torch import torch
import torch.backends.xnnpack
import torch.utils.model_dump import torch.utils.model_dump
import torch.utils.mobile_optimizer import torch.utils.mobile_optimizer
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS, skipIfNoXNNPACK
from torch.testing._internal.common_quantized import supported_qengines from torch.testing._internal.common_quantized import supported_qengines
@ -170,6 +171,7 @@ class TestModelDump(TestCase):
qmodel = self.get_quant_model() qmodel = self.get_quant_model()
self.do_dump_model(torch.jit.script(qmodel)) self.do_dump_model(torch.jit.script(qmodel))
@skipIfNoXNNPACK
@unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available") @unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available")
def test_optimized_quantized_model(self): def test_optimized_quantized_model(self):
qmodel = self.get_quant_model() qmodel = self.get_quant_model()

View File

@ -53,7 +53,6 @@ class TestXNNPACKOps(TestCase):
output_linearprepacked = torch.ops.prepacked.linear_clamp_run(input_data, packed_weight_bias) output_linearprepacked = torch.ops.prepacked.linear_clamp_run(input_data, packed_weight_bias)
torch.testing.assert_close(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3) torch.testing.assert_close(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3)
@given(batch_size=st.integers(0, 3), @given(batch_size=st.integers(0, 3),
input_channels_per_group=st.integers(1, 32), input_channels_per_group=st.integers(1, 32),
height=st.integers(5, 64), height=st.integers(5, 64),

View File

@ -63,6 +63,7 @@ from torch._six import string_classes
from torch import Tensor from torch import Tensor
import torch.backends.cudnn import torch.backends.cudnn
import torch.backends.mkl import torch.backends.mkl
import torch.backends.xnnpack
from enum import Enum from enum import Enum
from statistics import mean from statistics import mean
import functools import functools
@ -978,6 +979,14 @@ def _test_function(fn, device):
return fn(self, device) return fn(self, device)
return run_test_function return run_test_function
def skipIfNoXNNPACK(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
if not torch.backends.xnnpack.enabled:
raise unittest.SkipTest('XNNPACK must be enabled for these tests. Please build with USE_XNNPACK=1.')
else:
fn(*args, **kwargs)
return wrapper
def skipIfNoLapack(fn): def skipIfNoLapack(fn):
@wraps(fn) @wraps(fn)

View File

@ -42,6 +42,7 @@ SystemEnv = namedtuple('SystemEnv', [
'hip_runtime_version', 'hip_runtime_version',
'miopen_runtime_version', 'miopen_runtime_version',
'caching_allocator_config', 'caching_allocator_config',
'is_xnnpack_available',
]) ])
@ -292,6 +293,9 @@ def get_cachingallocator_config():
ca_config = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '') ca_config = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '')
return ca_config return ca_config
def is_xnnpack_available():
import torch.backends.xnnpack
return str(torch.backends.xnnpack.enabled) # type: ignore[attr-defined]
def get_env_info(): def get_env_info():
run_lambda = run run_lambda = run
@ -339,6 +343,7 @@ def get_env_info():
clang_version=get_clang_version(run_lambda), clang_version=get_clang_version(run_lambda),
cmake_version=get_cmake_version(run_lambda), cmake_version=get_cmake_version(run_lambda),
caching_allocator_config=get_cachingallocator_config(), caching_allocator_config=get_cachingallocator_config(),
is_xnnpack_available=is_xnnpack_available(),
) )
env_info_fmt = """ env_info_fmt = """
@ -362,6 +367,7 @@ Nvidia driver version: {nvidia_driver_version}
cuDNN version: {cudnn_version} cuDNN version: {cudnn_version}
HIP runtime version: {hip_runtime_version} HIP runtime version: {hip_runtime_version}
MIOpen runtime version: {miopen_runtime_version} MIOpen runtime version: {miopen_runtime_version}
Is XNNPACK available: {is_xnnpack_available}
Versions of relevant libraries: Versions of relevant libraries:
{pip_packages} {pip_packages}