Prepare for an update to the XNNPACK submodule (#72642)
Summary:
- Target SHA1: ae108ef49aa5623b896fc93d4298c49d1750d9ba
- Make USE_XNNPACK a dependent option that requires CMake >= 3.12.
- Print USE_XNNPACK in the CMake options summary, and report its availability from collect_env.py.
- Skip XNNPACK-based tests when XNNPACK is not available.
- Add a skipIfNoXNNPACK wrapper to skip such tests.
- Update the CMake version for the xenial-py3.7-gcc5.4 image to 3.12.4. This is required for the backwards-compatibility test: the PyTorch op schema depends on XNNPACK (see aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp for an example). The nightly build is assumed to have USE_XNNPACK=ON, so this change ensures that the test build can also have XNNPACK.
- HACK: skip the test_xnnpack_integration tests on ROCm.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72642

Reviewed By: kimishpatel

Differential Revision: D34456794

Pulled By: digantdesai

fbshipit-source-id: 85dbfe0211de7846d8a84321b14fdb061cd6c037
(cherry picked from commit 6cf48e7b64d6979962d701b5d493998262cc8bfa)
Committed by: PyTorch MergeBot
Parent: abb55c53b3
Commit: b2054d3025
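For context, a minimal sketch of how a test opts into the new guard; the class and test names here are hypothetical and not part of this commit:

    import torch
    import torch.backends.xnnpack
    from torch.testing._internal.common_utils import TestCase, run_tests, skipIfNoXNNPACK

    class XNNPACKGuardExample(TestCase):  # hypothetical test case, for illustration only
        @skipIfNoXNNPACK  # skips at call time when the build lacks XNNPACK
        def test_backend_flag(self):
            # the decorator consults this flag, so it is True whenever the body runs
            self.assertTrue(torch.backends.xnnpack.enabled)

    if __name__ == '__main__':
        run_tests()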
@@ -90,7 +90,7 @@ case "$image" in
     ;;
   pytorch-linux-xenial-py3.7-gcc5.4)
     ANACONDA_PYTHON_VERSION=3.7
-    CMAKE_VERSION=3.10.3
+    CMAKE_VERSION=3.12.4  # To make sure XNNPACK is enabled for the BACKWARDS_COMPAT_TEST used with this image
     GCC_VERSION=5
     PROTOBUF=yes
     DB=yes
@@ -285,7 +285,14 @@ option(USE_LITE_INTERPRETER_PROFILER "Enable " ON)
 option(USE_VULKAN_FP16_INFERENCE "Vulkan - Use fp16 inference" OFF)
 option(USE_VULKAN_RELAXED_PRECISION "Vulkan - Use relaxed precision math in the kernels (mediump)" OFF)
 option(USE_VULKAN_SHADERC_RUNTIME "Vulkan - Use runtime shader compilation as opposed to build-time (needs libshaderc)" OFF)
-option(USE_XNNPACK "Use XNNPACK" ON)
+# option USE_XNNPACK: try to enable xnnpack by default.
+set(XNNPACK_MIN_CMAKE_VER 3.12)
+cmake_dependent_option(
+  USE_XNNPACK "Use XNNPACK. Requires cmake >= ${XNNPACK_MIN_CMAKE_VER}." ON
+  "CMAKE_VERSION VERSION_GREATER_EQUAL ${XNNPACK_MIN_CMAKE_VER}" OFF)
+if(NOT USE_XNNPACK AND CMAKE_VERSION VERSION_LESS ${XNNPACK_MIN_CMAKE_VER})
+  message(WARNING "USE_XNNPACK is set to OFF. XNNPACK requires CMake version ${XNNPACK_MIN_CMAKE_VER} or greater.")
+endif()
 option(USE_ZMQ "Use ZMQ" OFF)
 option(USE_ZSTD "Use ZSTD" OFF)
 # Ensure that an MKLDNN build is the default for x86 CPUs
@@ -171,6 +171,7 @@ function(caffe2_print_configuration_summary)
   message(STATUS "  USE_PROF              : ${USE_PROF}")
   message(STATUS "  USE_QNNPACK           : ${USE_QNNPACK}")
   message(STATUS "  USE_PYTORCH_QNNPACK   : ${USE_PYTORCH_QNNPACK}")
+  message(STATUS "  USE_XNNPACK           : ${USE_XNNPACK}")
   message(STATUS "  USE_REDIS             : ${USE_REDIS}")
   message(STATUS "  USE_ROCKSDB           : ${USE_ROCKSDB}")
   message(STATUS "  USE_ZMQ               : ${USE_ZMQ}")
@@ -2,9 +2,9 @@
 
 import torch
 import torch._C
-import torch.backends.xnnpack
 import torch.nn.functional as F
 from torch.testing._internal.jit_utils import JitTestCase
+from torch.testing._internal.common_utils import skipIfNoXNNPACK
 
 class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
     def check_replacement(
@@ -36,6 +36,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
             original_source_ranges[replacements[node.kind()]],
         )
 
+    @skipIfNoXNNPACK
     def test_replace_conv1d_with_conv2d(self):
         class TestConv1d(torch.nn.Module):
             def __init__(self, weight, bias):
@@ -63,6 +64,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
             jit_pass=torch._C._jit_pass_transform_conv1d_to_conv2d,
         )
 
+    @skipIfNoXNNPACK
     def test_insert_pre_packed_linear_before_inline_and_conv_2d_op(self):
         class TestPrepackedLinearBeforeInlineAndConv2dOp(torch.nn.Module):
             def __init__(
@@ -139,6 +141,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
             jit_pass=torch._C._jit_pass_insert_prepacked_ops,
         )
 
+    @skipIfNoXNNPACK
     def test_insert_pre_packed_linear_op(self):
         self.check_replacement(
             model=torch.jit.trace(torch.nn.Linear(5, 4), torch.rand(3, 2, 5)),
@@ -230,6 +233,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
             jit_pass=torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv,
         )
 
+    @skipIfNoXNNPACK
     def test_fuse_activation_with_pack_ops_linear_conv2d_1(self):
         self.run_test_fuse_activation_with_pack_ops_linear_conv2d(
             linear_activation=F.hardtanh,
@@ -238,6 +242,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
             conv2d_activation_kind="aten::hardtanh_"
         )
 
+    @skipIfNoXNNPACK
     def test_fuse_activation_with_pack_ops_linear_conv2d_2(self):
         self.run_test_fuse_activation_with_pack_ops_linear_conv2d(
             linear_activation=F.hardtanh_,
@@ -246,6 +251,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
             conv2d_activation_kind="aten::hardtanh"
         )
 
+    @skipIfNoXNNPACK
     def test_fuse_activation_with_pack_ops_linear_conv2d_3(self):
         self.run_test_fuse_activation_with_pack_ops_linear_conv2d(
             linear_activation=F.relu,
@@ -254,6 +260,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
             conv2d_activation_kind="aten::relu_"
         )
 
+    @skipIfNoXNNPACK
     def test_fuse_activation_with_pack_ops_linear_conv2d_4(self):
         self.run_test_fuse_activation_with_pack_ops_linear_conv2d(
             linear_activation=F.relu_,
@@ -3,9 +3,8 @@
 import unittest
 import torch
 import torch.nn as nn
-import torch.backends.xnnpack
 import torch.utils.bundled_inputs
-from torch.testing._internal.common_utils import TestCase, run_tests
+from torch.testing._internal.common_utils import TestCase, run_tests, skipIfNoXNNPACK
 from torch.testing._internal.jit_utils import get_forward, get_forward_graph
 from torch.utils.mobile_optimizer import (LintCode,
                                           generate_mobile_module_lints,
@@ -24,9 +23,7 @@ FileCheck = torch._C.FileCheck
 
 class TestOptimizer(TestCase):
 
-    @unittest.skipUnless(torch.backends.xnnpack.enabled,
-                         " XNNPACK must be enabled for these tests."
-                         " Please build with USE_XNNPACK=1.")
+    @skipIfNoXNNPACK
     def test_optimize_for_mobile(self):
         batch_size = 2
         input_channels_per_group = 6
@@ -265,9 +262,7 @@ class TestOptimizer(TestCase):
                                    rtol=1e-2,
                                    atol=1e-3)
 
-    @unittest.skipUnless(torch.backends.xnnpack.enabled,
-                         " XNNPACK must be enabled for these tests."
-                         " Please build with USE_XNNPACK=1.")
+    @skipIfNoXNNPACK
     def test_quantized_conv_no_asan_failures(self):
         # There were ASAN failures when fold_conv_bn was run on
         # already quantized conv modules. Verifying that this does
@@ -361,6 +356,7 @@ class TestOptimizer(TestCase):
         bi_module_lint_list = generate_mobile_module_lints(bi_module)
         self.assertEqual(len(bi_module_lint_list), 0)
 
+    @skipIfNoXNNPACK
     def test_preserve_bundled_inputs_methods(self):
         class MyBundledInputModule(torch.nn.Module):
             def __init__(self):
@@ -415,9 +411,7 @@ class TestOptimizer(TestCase):
         incomplete_bi_module_optim = optimize_for_mobile(incomplete_bi_module, preserved_methods=['get_all_bundled_inputs'])
         self.assertTrue(hasattr(incomplete_bi_module_optim, 'get_all_bundled_inputs'))
 
-    @unittest.skipUnless(torch.backends.xnnpack.enabled,
-                         " XNNPACK must be enabled for these tests."
-                         " Please build with USE_XNNPACK=1.")
+    @skipIfNoXNNPACK
     def test_hoist_conv_packed_params(self):
 
         if 'qnnpack' not in torch.backends.quantized.supported_engines:
@@ -511,6 +505,7 @@ class TestOptimizer(TestCase):
         m_optim_res = m_optim(data)
         torch.testing.assert_close(m_res, m_optim_res, rtol=1e-2, atol=1e-3)
 
+    @skipIfNoXNNPACK
     @unittest.skipUnless(HAS_TORCHVISION, "Needs torchvision")
     def test_mobilenet_optimize_for_mobile(self):
         m = torchvision.models.mobilenet_v3_small()
@@ -10,9 +10,10 @@ import urllib
 import unittest
 
 import torch
+import torch.backends.xnnpack
 import torch.utils.model_dump
 import torch.utils.mobile_optimizer
-from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS
+from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS, skipIfNoXNNPACK
 from torch.testing._internal.common_quantized import supported_qengines
 
 
@@ -170,6 +171,7 @@ class TestModelDump(TestCase):
         qmodel = self.get_quant_model()
         self.do_dump_model(torch.jit.script(qmodel))
 
+    @skipIfNoXNNPACK
     @unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available")
     def test_optimized_quantized_model(self):
         qmodel = self.get_quant_model()
@@ -53,7 +53,6 @@ class TestXNNPACKOps(TestCase):
         output_linearprepacked = torch.ops.prepacked.linear_clamp_run(input_data, packed_weight_bias)
         torch.testing.assert_close(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3)
 
-
     @given(batch_size=st.integers(0, 3),
            input_channels_per_group=st.integers(1, 32),
            height=st.integers(5, 64),
@@ -63,6 +63,7 @@ from torch._six import string_classes
 from torch import Tensor
 import torch.backends.cudnn
 import torch.backends.mkl
+import torch.backends.xnnpack
 from enum import Enum
 from statistics import mean
 import functools
|
|||||||
return fn(self, device)
|
return fn(self, device)
|
||||||
return run_test_function
|
return run_test_function
|
||||||
|
|
||||||
|
def skipIfNoXNNPACK(fn):
|
||||||
|
@wraps(fn)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
if not torch.backends.xnnpack.enabled:
|
||||||
|
raise unittest.SkipTest('XNNPACK must be enabled for these tests. Please build with USE_XNNPACK=1.')
|
||||||
|
else:
|
||||||
|
fn(*args, **kwargs)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
def skipIfNoLapack(fn):
|
def skipIfNoLapack(fn):
|
||||||
@wraps(fn)
|
@wraps(fn)
|
||||||
|
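Because the wrapper raises unittest.SkipTest when the test is actually called, rather than marking it at collection time, it composes cleanly with the standard unittest decorators, as the stacked decorators in the test_model_dump and test_mobile_optimizer hunks above show. A minimal sketch, again with hypothetical names; HAS_EXTRA_DEP is a placeholder standing in for a probe like HAS_TORCHVISION:

    import unittest
    import torch
    import torch.backends.xnnpack
    from torch.testing._internal.common_utils import TestCase, run_tests, skipIfNoXNNPACK

    HAS_EXTRA_DEP = False  # assumption for the sketch; real code probes the import

    class StackedSkipsExample(TestCase):  # hypothetical test case
        @skipIfNoXNNPACK
        @unittest.skipUnless(HAS_EXTRA_DEP, "Needs the extra dependency")
        def test_combined(self):
            pass  # runs only when XNNPACK and the extra dependency are both present

    if __name__ == '__main__':
        run_tests()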
@@ -42,6 +42,7 @@ SystemEnv = namedtuple('SystemEnv', [
     'hip_runtime_version',
     'miopen_runtime_version',
     'caching_allocator_config',
+    'is_xnnpack_available',
 ])
 
 
@@ -292,6 +293,9 @@ def get_cachingallocator_config():
     ca_config = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '')
     return ca_config
 
+def is_xnnpack_available():
+    import torch.backends.xnnpack
+    return str(torch.backends.xnnpack.enabled)  # type: ignore[attr-defined]
 
 def get_env_info():
     run_lambda = run
@@ -339,6 +343,7 @@ def get_env_info():
         clang_version=get_clang_version(run_lambda),
         cmake_version=get_cmake_version(run_lambda),
         caching_allocator_config=get_cachingallocator_config(),
+        is_xnnpack_available=is_xnnpack_available(),
     )
 
 env_info_fmt = """
@@ -362,6 +367,7 @@ Nvidia driver version: {nvidia_driver_version}
 cuDNN version: {cudnn_version}
 HIP runtime version: {hip_runtime_version}
 MIOpen runtime version: {miopen_runtime_version}
+Is XNNPACK available: {is_xnnpack_available}
 
 Versions of relevant libraries:
 {pip_packages}
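With the new field threaded through SystemEnv, get_env_info(), and env_info_fmt above, XNNPACK availability is visible both in the report printed by `python -m torch.utils.collect_env` (a line of the form "Is XNNPACK available: True") and programmatically; a quick sketch of the latter:

    import torch.utils.collect_env as collect_env

    info = collect_env.get_env_info()  # returns the SystemEnv namedtuple from the hunk above
    print(info.is_xnnpack_available)   # 'True' or 'False'; note the value is stored as a string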