Prepare for an update to the XNNPACK submodule (#72642)

Summary:
- Target Sha1: ae108ef49aa5623b896fc93d4298c49d1750d9ba
- Make USE_XNNPACK a dependent option on cmake minimum version 3.12
- Print USE_XNNPACK under cmake options summary, and print the
  availability from collet_env.py
- Skip XNNPACK based tests when XNNPACK is not available
    - Add SkipIfNoXNNPACK wrapper to skip tests
- Update cmake version for xenial-py3.7-gcc5.4 image to 3.12.4
    - This is required for the backwards compatibility test.
      The PyTorch op schema is XNNPACK dependent. See,
      aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp for
      example. The nightly version is assumed to have USE_XNNPACK=ON,
      so with this change we ensure that the test build can also
      have XNNPACK.
- HACK: skipping test_xnnpack_integration tests on ROCM

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72642

Reviewed By: kimishpatel

Differential Revision: D34456794

Pulled By: digantdesai

fbshipit-source-id: 85dbfe0211de7846d8a84321b14fdb061cd6c037
(cherry picked from commit 6cf48e7b64d6979962d701b5d493998262cc8bfa)
This commit is contained in:
Digant Desai
2022-02-24 15:24:21 -08:00
committed by PyTorch MergeBot
parent abb55c53b3
commit b2054d3025
9 changed files with 42 additions and 16 deletions

View File

@ -90,7 +90,7 @@ case "$image" in
;;
pytorch-linux-xenial-py3.7-gcc5.4)
ANACONDA_PYTHON_VERSION=3.7
CMAKE_VERSION=3.10.3
CMAKE_VERSION=3.12.4 # To make sure XNNPACK is enabled for the BACKWARDS_COMPAT_TEST used with this image
GCC_VERSION=5
PROTOBUF=yes
DB=yes

View File

@ -285,7 +285,14 @@ option(USE_LITE_INTERPRETER_PROFILER "Enable " ON)
option(USE_VULKAN_FP16_INFERENCE "Vulkan - Use fp16 inference" OFF)
option(USE_VULKAN_RELAXED_PRECISION "Vulkan - Use relaxed precision math in the kernels (mediump)" OFF)
option(USE_VULKAN_SHADERC_RUNTIME "Vulkan - Use runtime shader compilation as opposed to build-time (needs libshaderc)" OFF)
option(USE_XNNPACK "Use XNNPACK" ON)
# option USE_XNNPACK: try to enable xnnpack by default.
set(XNNPACK_MIN_CMAKE_VER 3.12)
cmake_dependent_option(
USE_XNNPACK "Use XNNPACK. Requires cmake >= ${XNNPACK_MIN_CMAKE_VER}." ON
"CMAKE_VERSION VERSION_GREATER_EQUAL ${XNNPACK_MIN_CMAKE_VER}" OFF)
if(NOT USE_XNNPACK AND CMAKE_VERSION VERSION_LESS ${XNNPACK_MIN_CMAKE_VER})
message(WARNING "USE_XNNPACK is set to OFF. XNNPACK requires CMake version ${XNNPACK_MIN_CMAKE_VER} or greater.")
endif()
option(USE_ZMQ "Use ZMQ" OFF)
option(USE_ZSTD "Use ZSTD" OFF)
# Ensure that an MKLDNN build is the default for x86 CPUs

View File

@ -171,6 +171,7 @@ function(caffe2_print_configuration_summary)
message(STATUS " USE_PROF : ${USE_PROF}")
message(STATUS " USE_QNNPACK : ${USE_QNNPACK}")
message(STATUS " USE_PYTORCH_QNNPACK : ${USE_PYTORCH_QNNPACK}")
message(STATUS " USE_XNNPACK : ${USE_XNNPACK}")
message(STATUS " USE_REDIS : ${USE_REDIS}")
message(STATUS " USE_ROCKSDB : ${USE_ROCKSDB}")
message(STATUS " USE_ZMQ : ${USE_ZMQ}")

View File

@ -2,9 +2,9 @@
import torch
import torch._C
import torch.backends.xnnpack
import torch.nn.functional as F
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_utils import skipIfNoXNNPACK
class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
def check_replacement(
@ -36,6 +36,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
original_source_ranges[replacements[node.kind()]],
)
@skipIfNoXNNPACK
def test_replace_conv1d_with_conv2d(self):
class TestConv1d(torch.nn.Module):
def __init__(self, weight, bias):
@ -63,6 +64,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
jit_pass=torch._C._jit_pass_transform_conv1d_to_conv2d,
)
@skipIfNoXNNPACK
def test_insert_pre_packed_linear_before_inline_and_conv_2d_op(self):
class TestPrepackedLinearBeforeInlineAndConv2dOp(torch.nn.Module):
def __init__(
@ -139,6 +141,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
jit_pass=torch._C._jit_pass_insert_prepacked_ops,
)
@skipIfNoXNNPACK
def test_insert_pre_packed_linear_op(self):
self.check_replacement(
model=torch.jit.trace(torch.nn.Linear(5, 4), torch.rand(3, 2, 5)),
@ -230,6 +233,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
jit_pass=torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv,
)
@skipIfNoXNNPACK
def test_fuse_activation_with_pack_ops_linear_conv2d_1(self):
self.run_test_fuse_activation_with_pack_ops_linear_conv2d(
linear_activation=F.hardtanh,
@ -238,6 +242,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
conv2d_activation_kind="aten::hardtanh_"
)
@skipIfNoXNNPACK
def test_fuse_activation_with_pack_ops_linear_conv2d_2(self):
self.run_test_fuse_activation_with_pack_ops_linear_conv2d(
linear_activation=F.hardtanh_,
@ -246,6 +251,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
conv2d_activation_kind="aten::hardtanh"
)
@skipIfNoXNNPACK
def test_fuse_activation_with_pack_ops_linear_conv2d_3(self):
self.run_test_fuse_activation_with_pack_ops_linear_conv2d(
linear_activation=F.relu,
@ -254,6 +260,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
conv2d_activation_kind="aten::relu_"
)
@skipIfNoXNNPACK
def test_fuse_activation_with_pack_ops_linear_conv2d_4(self):
self.run_test_fuse_activation_with_pack_ops_linear_conv2d(
linear_activation=F.relu_,

View File

@ -3,9 +3,8 @@
import unittest
import torch
import torch.nn as nn
import torch.backends.xnnpack
import torch.utils.bundled_inputs
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfNoXNNPACK
from torch.testing._internal.jit_utils import get_forward, get_forward_graph
from torch.utils.mobile_optimizer import (LintCode,
generate_mobile_module_lints,
@ -24,9 +23,7 @@ FileCheck = torch._C.FileCheck
class TestOptimizer(TestCase):
@unittest.skipUnless(torch.backends.xnnpack.enabled,
" XNNPACK must be enabled for these tests."
" Please build with USE_XNNPACK=1.")
@skipIfNoXNNPACK
def test_optimize_for_mobile(self):
batch_size = 2
input_channels_per_group = 6
@ -265,9 +262,7 @@ class TestOptimizer(TestCase):
rtol=1e-2,
atol=1e-3)
@unittest.skipUnless(torch.backends.xnnpack.enabled,
" XNNPACK must be enabled for these tests."
" Please build with USE_XNNPACK=1.")
@skipIfNoXNNPACK
def test_quantized_conv_no_asan_failures(self):
# There were ASAN failures when fold_conv_bn was run on
# already quantized conv modules. Verifying that this does
@ -361,6 +356,7 @@ class TestOptimizer(TestCase):
bi_module_lint_list = generate_mobile_module_lints(bi_module)
self.assertEqual(len(bi_module_lint_list), 0)
@skipIfNoXNNPACK
def test_preserve_bundled_inputs_methods(self):
class MyBundledInputModule(torch.nn.Module):
def __init__(self):
@ -415,9 +411,7 @@ class TestOptimizer(TestCase):
incomplete_bi_module_optim = optimize_for_mobile(incomplete_bi_module, preserved_methods=['get_all_bundled_inputs'])
self.assertTrue(hasattr(incomplete_bi_module_optim, 'get_all_bundled_inputs'))
@unittest.skipUnless(torch.backends.xnnpack.enabled,
" XNNPACK must be enabled for these tests."
" Please build with USE_XNNPACK=1.")
@skipIfNoXNNPACK
def test_hoist_conv_packed_params(self):
if 'qnnpack' not in torch.backends.quantized.supported_engines:
@ -511,6 +505,7 @@ class TestOptimizer(TestCase):
m_optim_res = m_optim(data)
torch.testing.assert_close(m_res, m_optim_res, rtol=1e-2, atol=1e-3)
@skipIfNoXNNPACK
@unittest.skipUnless(HAS_TORCHVISION, "Needs torchvision")
def test_mobilenet_optimize_for_mobile(self):
m = torchvision.models.mobilenet_v3_small()

View File

@ -10,9 +10,10 @@ import urllib
import unittest
import torch
import torch.backends.xnnpack
import torch.utils.model_dump
import torch.utils.mobile_optimizer
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS, skipIfNoXNNPACK
from torch.testing._internal.common_quantized import supported_qengines
@ -170,6 +171,7 @@ class TestModelDump(TestCase):
qmodel = self.get_quant_model()
self.do_dump_model(torch.jit.script(qmodel))
@skipIfNoXNNPACK
@unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available")
def test_optimized_quantized_model(self):
qmodel = self.get_quant_model()

View File

@ -53,7 +53,6 @@ class TestXNNPACKOps(TestCase):
output_linearprepacked = torch.ops.prepacked.linear_clamp_run(input_data, packed_weight_bias)
torch.testing.assert_close(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3)
@given(batch_size=st.integers(0, 3),
input_channels_per_group=st.integers(1, 32),
height=st.integers(5, 64),

View File

@ -63,6 +63,7 @@ from torch._six import string_classes
from torch import Tensor
import torch.backends.cudnn
import torch.backends.mkl
import torch.backends.xnnpack
from enum import Enum
from statistics import mean
import functools
@ -978,6 +979,14 @@ def _test_function(fn, device):
return fn(self, device)
return run_test_function
def skipIfNoXNNPACK(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
if not torch.backends.xnnpack.enabled:
raise unittest.SkipTest('XNNPACK must be enabled for these tests. Please build with USE_XNNPACK=1.')
else:
fn(*args, **kwargs)
return wrapper
def skipIfNoLapack(fn):
@wraps(fn)

View File

@ -42,6 +42,7 @@ SystemEnv = namedtuple('SystemEnv', [
'hip_runtime_version',
'miopen_runtime_version',
'caching_allocator_config',
'is_xnnpack_available',
])
@ -292,6 +293,9 @@ def get_cachingallocator_config():
ca_config = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '')
return ca_config
def is_xnnpack_available():
import torch.backends.xnnpack
return str(torch.backends.xnnpack.enabled) # type: ignore[attr-defined]
def get_env_info():
run_lambda = run
@ -339,6 +343,7 @@ def get_env_info():
clang_version=get_clang_version(run_lambda),
cmake_version=get_cmake_version(run_lambda),
caching_allocator_config=get_cachingallocator_config(),
is_xnnpack_available=is_xnnpack_available(),
)
env_info_fmt = """
@ -362,6 +367,7 @@ Nvidia driver version: {nvidia_driver_version}
cuDNN version: {cudnn_version}
HIP runtime version: {hip_runtime_version}
MIOpen runtime version: {miopen_runtime_version}
Is XNNPACK available: {is_xnnpack_available}
Versions of relevant libraries:
{pip_packages}