Mobile Backend: NHWC memory layout + XNNPACK integration. (#33722)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33722

In order to improve CPU performance of floating-point models on mobile, this PR introduces a new mobile CPU backend that implements the most common mobile operators with NHWC memory layout support through integration with XNNPACK.

XNNPACK itself, and this code path, are currently only included in the build, while the actual integration is gated behind USE_XNNPACK preprocessor guards.  This preprocessor symbol is intentionally not passed on to the compiler, so as to enable this rollout in multiple stages in follow-up PRs.  This changeset builds XNNPACK as part of the build if the identically named USE_XNNPACK CMake variable (defaulted to ON) is enabled, but does not expose or enable this code path in any other way.
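Concretely, the guard mirrors the pattern of the fallback translation unit further down in this diff; the function below is purely illustrative and not part of the changeset:

#include <ATen/ATen.h>

#ifdef USE_XNNPACK
at::Tensor example_op(const at::Tensor& input) {
  // XNNPACK-backed implementation, compiled only when USE_XNNPACK is defined.
  return input;  // placeholder body
}
#else
at::Tensor example_op(const at::Tensor& input) {
  // Stub compiled when XNNPACK support is absent: fail loudly, never silently.
  TORCH_CHECK(false, "Not Implemented! Reason: PyTorch not built with XNNPACK support.");
}
#endif /* USE_XNNPACK */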

Furthermore, it is worth pointing out that in order to efficiently map models to these operators, some front-end method of exposing this backend to the user is needed.  The less efficient approach would be to hook these operators into their corresponding native implementations, provided that a series of XNNPACK-specific conditions are met, much like how NNPACK is integrated with PyTorch today.
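Purely as an illustration (again, this PR deliberately does not wire this up), such a hook could look roughly like the sketch below, reusing the use_convolution2d / convolution2d helpers added in this changeset; the wrapper name and the fallback call are schematic, and the sketch assumes declarations for these helpers are visible:

// Hypothetical dispatch hook, not part of this PR: try XNNPACK only when its
// preconditions hold, otherwise fall through to the existing implementation.
at::Tensor conv2d_with_xnnpack_hook(
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& bias,
    const at::IntArrayRef stride,
    const at::IntArrayRef padding,
    const at::IntArrayRef dilation,
    const int64_t groups) {
#ifdef USE_XNNPACK
  if (at::native::xnnpack::use_convolution2d(
          input, weight, bias, padding, stride, dilation, groups)) {
    return at::native::xnnpack::convolution2d(
        input, weight, bias, padding, stride, dilation, groups);
  }
#endif
  return at::conv2d(input, weight, bias, stride, padding, dilation, groups);
}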

Having said that, while the above approach is still expected to outperform NNPACK based on the benchmarks I ran, it would leave a considerable gap between the performance achieved and the maximum performance XNNPACK enables, as it does not provide a way to compute one-time operations ahead of time and factor them out of the innermost forward() loop.

The better solution, and one we will decide on soon, would involve either a JIT pass that maps nn operators onto these newly introduced operators while allowing one-time calculations to be factored out, much like quantized mobile models, or new eager-mode modules that call directly into these implementations, whether through c10 or some other mechanism, likewise decoupling op creation from op execution.
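For intuition, that decoupling looks roughly like the sketch below, using the _conv2d_prepack and _conv2d_packed operators registered in this changeset; the surrounding loop and the tensor names are illustrative only:

// One-time cost, hoisted out of the inference loop: repack the weights into
// NHWC and create the underlying XNNPACK operator.
const at::Tensor packed = at::_conv2d_prepack(
    weight, bias, stride, padding, dilation, groups,
    /*output_min=*/c10::nullopt, /*output_max=*/c10::nullopt);

// Per-inference cost: only the cheap setup-and-run step repeats.
for (const at::Tensor& input : inputs) {
  const at::Tensor output = at::_conv2d_packed(packed, input);
  // ... consume output ...
}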

This PR does not include any of the front-end changes mentioned above, nor does it include the mobile threadpool unification present in the original https://github.com/pytorch/pytorch/issues/30644.  Furthermore, this code path appears to be faster than NNPACK in a good number of use cases, which could allow us to remove NNPACK from ATen and simplify the codebase, provided there is widespread support for such a move.

Regardless, these changes will be introduced gradually and in a more controlled way in subsequent PRs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/32509

Test Plan:
Build: CI
Functionality: Not exposed

Reviewed By: dreiss

Differential Revision: D20069796

Pulled By: AshkanAliabadi

fbshipit-source-id: d46c1c91d4bea91979ea5bd46971ced5417d309c
Author: Ashkan Aliabadi
Date: 2020-02-24 21:53:34 -08:00
Committed by: Facebook Github Bot
Parent: 2a4aad7466
Commit: 6aecfd1e80
23 changed files with 1088 additions and 178 deletions


@ -14,7 +14,7 @@ mkdir -p ${ZIP_DIR}/src
cp -R ${ARTIFACTS_DIR}/arm64/include ${ZIP_DIR}/install/
# build a FAT binary
cd ${ZIP_DIR}/install/lib
target_libs=(libc10.a libclog.a libcpuinfo.a libeigen_blas.a libpytorch_qnnpack.a libtorch_cpu.a libtorch.a)
target_libs=(libc10.a libclog.a libcpuinfo.a libeigen_blas.a libpytorch_qnnpack.a libtorch_cpu.a libtorch.a libXNNPACK.a)
for lib in ${target_libs[*]}
do
if [ -f "${ARTIFACTS_DIR}/x86_64/lib/${lib}" ] && [ -f "${ARTIFACTS_DIR}/arm64/lib/${lib}" ]; then

.gitmodules (vendored)

@ -118,3 +118,7 @@
ignore = dirty
path = android/libs/fbjni
url = https://github.com/facebookincubator/fbjni.git
[submodule "third_party/XNNPACK"]
path = third_party/XNNPACK
url = https://github.com/AshkanAliabadi/XNNPACK.git
branch = xnnpack_pytorch_merge_temp


@ -185,6 +185,7 @@ option(USE_SNPE "Use Qualcomm's SNPE library" OFF)
option(USE_SYSTEM_EIGEN_INSTALL
"Use system Eigen instead of the one under third_party" OFF)
option(USE_TENSORRT "Using Nvidia TensorRT library" OFF)
option(USE_XNNPACK "Use XNNPACK" ON)
option(USE_ZMQ "Use ZMQ" OFF)
option(USE_ZSTD "Use ZSTD" OFF)
cmake_dependent_option(
@ -416,6 +417,10 @@ if(USE_PYTORCH_QNNPACK)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_PYTORCH_QNNPACK")
endif()
if(USE_XNNPACK)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_XNNPACK")
endif()
# ---[ Whitelist file if whitelist is specified
include(cmake/Whitelist.cmake)


@ -83,6 +83,7 @@ if (ANDROID_ABI)
import_static_lib(libtorch_cpu)
import_static_lib(libc10)
import_static_lib(libnnpack)
import_static_lib(libXNNPACK)
import_static_lib(libpytorch_qnnpack)
import_static_lib(libeigen_blas)
import_static_lib(libcpuinfo)
@ -98,6 +99,7 @@ if (ANDROID_ABI)
-Wl,--no-whole-archive
libc10
libnnpack
libXNNPACK
libpytorch_qnnpack
libeigen_blas
libcpuinfo
@ -113,6 +115,7 @@ else()
torch_cpu
c10
nnpack
XNNPACK
pytorch_qnnpack
cpuinfo
clog


@ -84,8 +84,11 @@ FILE(GLOB native_sparse_hip_cpp "native/sparse/hip/*.cpp")
FILE(GLOB native_quantized_hip_hip "native/quantized/hip/*.hip")
FILE(GLOB native_quantized_hip_cpp "native/quantized/hip/*.cpp")
# XNNPACK
FILE(GLOB native_xnnpack "native/xnnpack/*.cpp")
add_subdirectory(quantized)
set(all_cpu_cpp ${base_cpp} ${ATen_CORE_SRCS} ${native_cpp} ${native_sparse_cpp} ${native_quantized_cpp} ${native_mkl_cpp} ${native_mkldnn_cpp} ${generated_cpp} ${core_generated_cpp} ${ATen_CPU_SRCS} ${ATen_QUANTIZED_SRCS} ${cpu_kernel_cpp})
set(all_cpu_cpp ${base_cpp} ${ATen_CORE_SRCS} ${native_cpp} ${native_sparse_cpp} ${native_quantized_cpp} ${native_mkl_cpp} ${native_mkldnn_cpp} ${native_xnnpack} ${generated_cpp} ${core_generated_cpp} ${ATen_CPU_SRCS} ${ATen_QUANTIZED_SRCS} ${cpu_kernel_cpp})
if(AT_MKL_ENABLED)
set(all_cpu_cpp ${all_cpu_cpp} ${mkl_cpp})
endif()


@ -775,6 +775,10 @@
- func: conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor
- func: _conv2d_prepack(Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1, float? output_min=None, float? output_max=None) -> Tensor
- func: _conv2d_packed(Tensor packed_weight, Tensor input) -> Tensor
- func: conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor
- func: conv_tbc(Tensor self, Tensor weight, Tensor bias, int pad=0) -> Tensor
@ -1575,6 +1579,10 @@
- func: linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
python_module: nn
- func: _linear_prepack(Tensor weight, Tensor? bias=None, float? output_min=None, float? output_max=None) -> Tensor
- func: _linear_packed(Tensor packed_weight, Tensor input) -> Tensor
- func: mkldnn_linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
python_module: nn
dispatch:


@ -0,0 +1,55 @@
#pragma once
#include <c10/core/CPUAllocator.h>
namespace at {
namespace native {
// QNNPACK and XNNPACK may access the input and / or output tensors out of bounds.
// This behavior will trigger ASAN, and may result in a segfault if the accessed
// memory just so happens to fall on a page the current process has no read access
// to. Here we define a custom allocator that allocates the extra storage required
// to keep this behavior safe.
//
// PreGuardBytes: Number of guard bytes to allocate before the allocation.
// PostGuardBytes: Number of guard bytes to allocate after the allocation.
template <uint32_t PreGuardBytes, uint32_t PostGuardBytes>
class GuardingAllocator final : public at::Allocator {
public:
GuardingAllocator() = default;
virtual ~GuardingAllocator() override = default;
static void deleter(void* pointer) {
const Cast memory{pointer};
c10::free_cpu(memory.as_byte_ptr - kPreGuardBytes);
}
virtual DataPtr allocate(size_t nbytes) const override {
Cast memory{c10::alloc_cpu(kPreGuardBytes + nbytes + kPostGuardBytes)};
memory.as_byte_ptr += kPreGuardBytes;
return {
memory.as_void_ptr,
memory.as_void_ptr,
&deleter,
at::Device(DeviceType::CPU),
};
}
virtual DeleterFnPtr raw_deleter() const override {
return deleter;
}
private:
static constexpr uint32_t kPreGuardBytes = PreGuardBytes;
static constexpr uint32_t kPostGuardBytes = PostGuardBytes;
union Cast final {
void * const as_void_ptr;
uint8_t * as_byte_ptr;
};
};
} // namespace native
} // namespace at
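A brief usage sketch of the allocator above; this mirrors how Factory.cpp further down instantiates it, and assumes <xnnpack.h> is visible for the XNN_EXTRA_BYTES constant:

#include <ATen/native/utils/Allocator.h>
#include <xnnpack.h>

at::DataPtr allocate_with_tail_guard(const size_t nbytes) {
  // Zero guard bytes before the allocation and XNN_EXTRA_BYTES after it, so
  // XNNPACK's over-reads past the end of a buffer land on memory this process
  // owns rather than on an unmapped page.
  static at::native::GuardingAllocator<0u, XNN_EXTRA_BYTES> allocator;
  return allocator.allocate(nbytes);
}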


@ -0,0 +1,82 @@
#pragma once
#include <ATen/ATen.h>
#ifdef USE_XNNPACK
#include <xnnpack.h>
namespace at {
namespace native {
namespace xnnpack {
namespace internal {
struct Layout final {
// 4D Activation Maps
struct Activation4D final {
static constexpr size_t batch = 0u;
static constexpr size_t channels = 1u;
static constexpr size_t height = 2u;
static constexpr size_t width = 3u;
};
// ND Activation Maps
struct ActivationND final {
// Some operators may not be limited to 4 dimensional tensors. In that scenario,
// XNNPACK denotes that operator with an _nc suffix and expects all dimensions,
// except channels, to be flattened into one argument: batch_size.
static int64_t batch(const IntArrayRef tensor) {
if (C10_UNLIKELY(tensor.empty())) {
return -1;
}
// Handle the case where batch size is zero.
int64_t batch = std::max<int64_t>(1, tensor[0]);
for (size_t index = 1u; index < (tensor.size() - 1u); ++index) {
batch *= tensor[index];
}
return batch;
};
static int64_t channel(const IntArrayRef tensor) {
if (C10_UNLIKELY(tensor.empty())) {
return -1;
}
return tensor.back();
};
};
// Convolution Filters
struct Filter final {
static constexpr size_t output = 0u;
static constexpr size_t input = 1u;
static constexpr size_t height = 2u;
static constexpr size_t width = 3u;
};
// Parameters (Pooling Kernels, Dilation, Padding, Stride, etc.)
struct Parameter final {
static constexpr size_t height = 0u;
static constexpr size_t width = 1u;
};
};
struct Deleter final {
void operator()(const xnn_operator_t op) const {
xnn_delete_operator(op);
}
};
using Operator = std::unique_ptr<xnn_operator, Deleter>;
bool available();
} // namespace internal
} // namespace xnnpack
} // namespace native
} // namespace at
#endif /* USE_XNNPACK */
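As a worked example of the ActivationND flattening described above, assuming <ATen/native/xnnpack/Common.h> is included and with arbitrary sizes:

// An input of size [2, 3, 4, 5] handed to an *_nc operator flattens every
// dimension except the trailing channel dimension into batch_size.
using at::native::xnnpack::internal::Layout;
const std::vector<int64_t> sizes{2, 3, 4, 5};
const int64_t batch = Layout::ActivationND::batch(sizes);      // 2 * 3 * 4 == 24
const int64_t channels = Layout::ActivationND::channel(sizes); // sizes.back() == 5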


@ -0,0 +1,316 @@
#ifdef USE_XNNPACK
#include <ATen/cpp_custom_type_hack.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/utils/ParamUtils.h>
#include <ATen/native/xnnpack/Common.h>
#include <ATen/native/xnnpack/Factory.h>
namespace at {
namespace native {
namespace xnnpack {
namespace internal {
namespace convolution2d {
struct Context final {
Operator convolution_op;
std::vector<int64_t> weight_size;
std::vector<int64_t> padding;
std::vector<int64_t> stride;
std::vector<int64_t> dilation;
static constexpr float kMin = -std::numeric_limits<float>::infinity();
static constexpr float kMax = std::numeric_limits<float>::infinity();
};
namespace {
// Supports NHWC and NCHW FP32 convolutions with any valid
// - kernel size
// - padding
// - stride
// - dilation
// - grouping
// TODO: Decouple and improve error handling and messages.
bool available(
const Tensor& weight,
const c10::optional<Tensor>& bias,
const IntArrayRef padding,
const IntArrayRef stride,
const IntArrayRef dilation,
const int64_t groups,
const float output_min,
const float output_max) {
// XNNPACK
return xnnpack::internal::available() &&
// Weight
(4 == weight.ndimension()) &&
(weight.size(Layout::Filter::height) > 0) &&
(weight.size(Layout::Filter::width) > 0) &&
(c10::DeviceType::CPU == weight.device().type()) &&
(kFloat == weight.scalar_type()) &&
// Bias
((bias && bias->defined()) ? ((1 == bias->ndimension()) &&
(c10::DeviceType::CPU == bias->device().type()) &&
(kFloat == bias->scalar_type()) &&
(weight.size(Layout::Filter::output)) == bias->size(0))
: true) &&
// Padding
(padding[Layout::Parameter::height] >= 0) &&
(padding[Layout::Parameter::width] >= 0) &&
// Stride
(stride[Layout::Parameter::height] > 0) &&
(stride[Layout::Parameter::width] > 0) &&
// Dilation
(dilation[Layout::Parameter::height] > 0) &&
(dilation[Layout::Parameter::width] > 0) &&
// Groups
(groups > 0) &&
// Input
(weight.size(Layout::Filter::input) > 0) &&
// Output
(weight.size(Layout::Filter::output) > 0) &&
// Output - Groups
((weight.size(Layout::Filter::output) % groups) == 0) &&
// Output Min / Max
(output_max > output_min) &&
true;
}
Context create(
const Tensor& weight,
const c10::optional<Tensor>& bias,
const IntArrayRef padding_,
const IntArrayRef stride_,
const IntArrayRef dilation_,
const int64_t groups,
const float output_min,
const float output_max) {
const auto padding = expand_param_if_needed(padding_, "padding", 2);
const auto stride = expand_param_if_needed(stride_, "stride", 2);
const auto dilation = expand_param_if_needed(dilation_, "dilation", 2);
const Tensor weight_nhwc = weight.contiguous(MemoryFormat::ChannelsLast);
TORCH_CHECK(
available(
weight_nhwc,
bias,
padding,
stride,
dilation,
groups,
output_min,
output_max),
"XNNPACK Convolution not available! "
"Reason: The provided (weight, bias, padding, stride, dilation, groups, output_min, output_max) "
"parameters are either invalid individually or their combination is not supported by XNNPACK.");
xnn_operator_t convolution_op{};
const xnn_status create_status = xnn_create_convolution2d_nhwc_f32(
padding[Layout::Parameter::height], // input_padding_top
padding[Layout::Parameter::width], // input_padding_right
padding[Layout::Parameter::height], // input_padding_bottom
padding[Layout::Parameter::width], // input_padding_left
weight_nhwc.size(Layout::Filter::height), // kernel_height
weight_nhwc.size(Layout::Filter::width), // kernel_width
stride[Layout::Parameter::height], // subsampling_height
stride[Layout::Parameter::width], // subsampling_width
dilation[Layout::Parameter::height], // dilation_height
dilation[Layout::Parameter::width], // dilation_width
groups, // groups
weight_nhwc.size(Layout::Filter::input), // group_input_channels
weight_nhwc.size(Layout::Filter::output) / groups, // group_output_channels
weight_nhwc.size(Layout::Filter::input) * groups, // input_pixel_stride
weight_nhwc.size(Layout::Filter::output), // output_pixel_stride
weight_nhwc.data_ptr<float>(), // kernel
(bias && bias->defined()) ? bias->data_ptr<float>() : nullptr, // bias
output_min, // output_min
output_max, // output_max
0u, // flags
&convolution_op); // operator
TORCH_CHECK(
xnn_status_success == create_status,
"xnn_create_convolution2d_nhwc_f32 failed!");
return Context{
Operator(convolution_op),
weight_nhwc.sizes().vec(),
padding,
stride,
dilation,
};
}
// TODO: Decouple and improve error handling and messages.
bool usable(const Tensor& input) {
// Input
return (4 == input.ndimension()) &&
(c10::DeviceType::CPU == input.device().type()) &&
(kFloat == input.scalar_type()) &&
(input.size(Layout::Activation4D::batch) > 0) &&
(input.size(Layout::Activation4D::channels) > 0) &&
(input.size(Layout::Activation4D::height) > 0) &&
(input.size(Layout::Activation4D::width) > 0) &&
true;
}
Tensor run(
const Context& context,
const Tensor& input) {
using namespace internal;
const Tensor input_nhwc = input.contiguous(MemoryFormat::ChannelsLast);
TORCH_CHECK(
usable(input_nhwc),
"XNNPACK Convolution not usable! "
"Reason: The provided input tensor is either invalid or unsupported by XNNPACK.");
Tensor output = empty_with_tail_padding(
conv_output_size(
input_nhwc.sizes(),
context.weight_size,
context.padding,
context.stride,
context.dilation),
input_nhwc.options().dtype(),
MemoryFormat::ChannelsLast);
const xnn_status setup_status = xnn_setup_convolution2d_nhwc_f32(
context.convolution_op.get(), // operator
input_nhwc.size(Layout::Activation4D::batch), // batch_size
input_nhwc.size(Layout::Activation4D::height), // input_height
input_nhwc.size(Layout::Activation4D::width), // input_width
input_nhwc.data_ptr<float>(), // input
output.data_ptr<float>(), // output
nullptr); // threadpool
TORCH_CHECK(
xnn_status_success == setup_status,
"xnn_setup_convolution2d_nhwc_f32 failed!");
const xnn_status run_status = xnn_run_operator(
context.convolution_op.get(), // operator
nullptr); // threadpool
TORCH_INTERNAL_ASSERT(
xnn_status_success == run_status,
"xnn_run_operator failed!");
return output.contiguous(input.suggest_memory_format());
}
Tensor create_and_run(
const Tensor& input,
const Tensor& weight,
const Tensor& bias,
const IntArrayRef padding,
const IntArrayRef stride,
const IntArrayRef dilation,
const int64_t groups,
const float output_min,
const float output_max) {
return run(
create(
weight,
bias,
padding,
stride,
dilation,
groups,
output_min,
output_max),
input);
}
} // namespace
} // namespace convolution2d
} // namespace internal
bool use_convolution2d(
const Tensor& input,
const Tensor& weight,
const Tensor& bias,
const IntArrayRef padding,
const IntArrayRef stride,
const IntArrayRef dilation,
const int64_t groups) {
return internal::convolution2d::available(
weight,
bias,
padding,
stride,
dilation,
groups,
internal::convolution2d::Context::kMin,
internal::convolution2d::Context::kMax) &&
internal::convolution2d::usable(input);
}
Tensor convolution2d(
const Tensor& input,
const Tensor& weight,
const Tensor& bias,
const IntArrayRef padding,
const IntArrayRef stride,
const IntArrayRef dilation,
const int64_t groups) {
return internal::convolution2d::create_and_run(
input,
weight,
bias,
padding,
stride,
dilation,
groups,
internal::convolution2d::Context::kMin,
internal::convolution2d::Context::kMax);
}
} // namespace xnnpack
at::Tensor _conv2d_prepack(
const Tensor& weight,
const Tensor& bias,
const IntArrayRef stride,
const IntArrayRef padding,
const IntArrayRef dilation,
const int64_t groups,
const c10::optional<double> output_min,
const c10::optional<double> output_max) {
return cpp_custom_type_hack::create(
std::make_unique<xnnpack::internal::convolution2d::Context>(
xnnpack::internal::convolution2d::create(
weight,
bias,
padding.vec(),
stride.vec(),
dilation.vec(),
groups,
output_min ? *output_min : xnnpack::internal::convolution2d::Context::kMin,
output_max ? *output_max : xnnpack::internal::convolution2d::Context::kMax)),
weight.options());
}
at::Tensor _conv2d_packed(
const Tensor& packed_weight,
const Tensor& input) {
return xnnpack::internal::convolution2d::run(
cpp_custom_type_hack::cast<xnnpack::internal::convolution2d::Context>(packed_weight),
input);
}
} // namespace native
} // namespace at
namespace caffe2 {
CAFFE_KNOWN_TYPE(at::native::xnnpack::internal::convolution2d::Context);
} // namespace caffe2
#endif /* USE_XNNPACK */


@ -0,0 +1,38 @@
#ifdef USE_XNNPACK
#include <ATen/native/xnnpack/Factory.h>
#include <ATen/native/utils/Allocator.h>
namespace at {
namespace native {
namespace xnnpack {
namespace internal {
Tensor empty_with_tail_padding(
const IntArrayRef size,
const caffe2::TypeMeta dtype,
const c10::MemoryFormat memory_format) {
static GuardingAllocator<0u, XNN_EXTRA_BYTES> allocator;
const int64_t nelements = prod_intlist(size);
Tensor tensor(
c10::make_intrusive<c10::TensorImpl>(
c10::Storage{
dtype,
nelements,
allocator.allocate(nelements * dtype.itemsize()),
&allocator,
/*resizable=*/true,
},
DispatchKeySet{DispatchKey::CPUTensorId}));
return tensor.resize_(size, memory_format);
}
} // namespace internal
} // namespace xnnpack
} // namespace native
} // namespace at
#endif /* USE_XNNPACK */


@ -0,0 +1,25 @@
#pragma once
#include <ATen/native/xnnpack/Common.h>
#ifdef USE_XNNPACK
namespace at {
namespace native {
namespace xnnpack {
namespace internal {
// TODO: Remove this function when at::native::empty() is modified to accept a
// custom memory allocator.
at::Tensor empty_with_tail_padding(
IntArrayRef size,
const caffe2::TypeMeta dtype,
c10::MemoryFormat memory_format);
} // namespace internal
} // namespace xnnpack
} // namespace native
} // namespace at
#endif /* USE_XNNPACK */


@ -0,0 +1,63 @@
#ifdef USE_XNNPACK
#include <ATen/native/xnnpack/Common.h>
namespace at {
namespace native {
namespace xnnpack {
namespace internal {
namespace {
bool is_initialized_ = false;
bool initialize() {
using namespace internal;
// This implementation allows for retries.
if (!is_initialized_) {
const xnn_status status = xnn_initialize(nullptr);
is_initialized_ = (xnn_status_success == status);
if (!is_initialized_) {
if (xnn_status_out_of_memory == status) {
TORCH_WARN_ONCE("Failed to initialize XNNPACK! Reason: Out of memory.");
} else if (xnn_status_unsupported_hardware == status) {
TORCH_WARN_ONCE("Failed to initialize XNNPACK! Reason: Unsupported hardware.");
} else {
TORCH_WARN_ONCE("Failed to initialize XNNPACK! Reason: Unknown error!");
}
}
}
return is_initialized_;
}
bool deinitialize() {
using namespace internal;
// This implementation allows for retries.
if (is_initialized_) {
const xnn_status status = xnn_deinitialize();
is_initialized_ = !(xnn_status_success == status);
if (is_initialized_) {
TORCH_WARN_ONCE("Failed to uninitialize XNNPACK! Reason: Unknown error!");
}
}
return !is_initialized_;
}
} // namespace
bool available() {
// Add extra conditions here that should disable mobile CPU impl at runtime in its totality.
return internal::initialize();
}
} // namespace internal
} // namespace xnnpack
} // namespace native
} // namespace at
#endif /* USE_XNNPACK */


@ -0,0 +1,223 @@
#ifdef USE_XNNPACK
#include <ATen/cpp_custom_type_hack.h>
#include <ATen/native/xnnpack/Common.h>
#include <ATen/native/xnnpack/Factory.h>
namespace at {
namespace native {
namespace xnnpack {
namespace internal {
namespace linear {
struct Context final {
Operator linear_op;
struct Output final {
int64_t channels;
} output;
static constexpr float kMin = -std::numeric_limits<float>::infinity();
static constexpr float kMax = std::numeric_limits<float>::infinity();
};
namespace {
// Supports NHWC and NCHW FP32 linear operators.
// TODO: Decouple and improve error handling and messages.
bool available(
const Tensor& weight,
const c10::optional<Tensor>& bias,
const float output_min,
const float output_max) {
// XNNPACK
return xnnpack::internal::available() &&
// Weight
(2 == weight.ndimension()) &&
(c10::DeviceType::CPU == weight.device().type()) &&
(kFloat == weight.scalar_type()) &&
// Bias
((bias && bias->defined()) ? ((1 == bias->ndimension()) &&
(c10::DeviceType::CPU == bias->device().type()) &&
(kFloat == bias->scalar_type()) &&
(weight.size(Layout::Filter::output)) == bias->size(0))
: true) &&
// Output Min / Max
(output_max > output_min) &&
true;
}
Context create(
const Tensor& weight,
const c10::optional<Tensor>& bias,
const float output_min,
const float output_max) {
const Tensor weight_contig = weight.contiguous();
TORCH_CHECK(
available(
weight_contig,
bias,
output_min,
output_max),
"XNNPACK Linear not available! "
"Reason: The provided (weight, bias, output_min, output_max) parameters are "
"either invalid individually or their combination is not supported by XNNPACK.");
xnn_operator_t linear_op{};
const xnn_status create_status = xnn_create_fully_connected_nc_f32(
weight_contig.size(Layout::Filter::input), // input_channels
weight_contig.size(Layout::Filter::output), // output_channels
weight_contig.size(Layout::Filter::input), // input_pixel_stride
weight_contig.size(Layout::Filter::output), // output_pixel_stride
weight_contig.data_ptr<float>(), // kernel
(bias && bias->defined()) ? bias->data_ptr<float>() : nullptr, // bias
output_min, // output_min
output_max, // output_max
0u, // flags
&linear_op); // operator
TORCH_CHECK(
xnn_status_success == create_status,
"xnn_create_fully_connected_nc_f32 failed!");
return Context{
Operator(linear_op),
{
weight_contig.size(Layout::Filter::output),
}
};
}
// TODO: Decouple and improve error handling and messages.
bool usable(const Tensor& input) {
// Input
return (2 <= input.ndimension()) &&
(c10::DeviceType::CPU == input.device().type()) &&
(kFloat == input.scalar_type()) &&
true;
}
Tensor run(
const Context& context,
const Tensor& input) {
using namespace internal;
const Tensor& input_contig = input.contiguous();
TORCH_CHECK(
usable(input_contig),
"XNNPACK Linear not usable! "
"Reason: The provided input tensor is either invalid or unsupported by XNNPACK.");
const IntArrayRef input_size = input_contig.sizes();
std::vector<int64_t> output_size(input_size.cbegin(), input_size.cend());
output_size.back() = context.output.channels;
Tensor output = empty_with_tail_padding(
output_size,
input_contig.options().dtype(),
input_contig.suggest_memory_format());
const xnn_status setup_status = xnn_setup_fully_connected_nc_f32(
context.linear_op.get(), // operator
Layout::ActivationND::batch(input_contig.sizes()), // Batch,
input_contig.data_ptr<float>(), // input
output.data_ptr<float>(), // output
nullptr); // threadpool
TORCH_CHECK(
xnn_status_success == setup_status,
"xnn_setup_fully_connected_nc_f32 failed!");
const xnn_status run_status = xnn_run_operator(
context.linear_op.get(), // operator
nullptr); // threadpool
TORCH_INTERNAL_ASSERT(
xnn_status_success == run_status,
"xnn_run_operator failed!");
return output;
}
Tensor create_and_run(
const Tensor& input,
const Tensor& weight,
const Tensor& bias,
const float output_min,
const float output_max) {
return run(
create(
weight,
bias,
output_min,
output_max),
input);
}
} // namespace
} // namespace linear
} // namespace internal
bool use_linear(
const Tensor& input,
const Tensor& weight,
const Tensor& bias) {
return internal::linear::available(
weight,
bias,
internal::linear::Context::kMin,
internal::linear::Context::kMax) &&
internal::linear::usable(input);
}
Tensor linear(
const Tensor& input,
const Tensor& weight,
const Tensor& bias) {
return internal::linear::create_and_run(
input,
weight,
bias,
internal::linear::Context::kMin,
internal::linear::Context::kMax);
}
} // namespace xnnpack
Tensor _linear_prepack(
const Tensor& weight,
const Tensor& bias,
const c10::optional<double> output_min,
const c10::optional<double> output_max) {
return cpp_custom_type_hack::create(
std::make_unique<xnnpack::internal::linear::Context>(
xnnpack::internal::linear::create(
weight,
bias,
output_min ? *output_min : xnnpack::internal::linear::Context::kMin,
output_max ? *output_max : xnnpack::internal::linear::Context::kMax)),
weight.options());
}
Tensor _linear_packed(
const Tensor& packed_weight,
const Tensor& input) {
return xnnpack::internal::linear::run(
cpp_custom_type_hack::cast<xnnpack::internal::linear::Context>(packed_weight),
input);
}
} // namespace native
} // namespace at
namespace caffe2 {
CAFFE_KNOWN_TYPE(at::native::xnnpack::internal::linear::Context);
} // namespace caffe2
#endif /* USE_XNNPACK */


@ -0,0 +1,96 @@
#ifndef USE_XNNPACK
#include <ATen/native/xnnpack/Common.h>
namespace at {
namespace native {
namespace xnnpack {
namespace internal {
namespace {
constexpr const char * const kError =
"Not Implemented! Reason: PyTorch not built with XNNPACK support.";
} // namespace
} // namespace internal
bool available() {
return false;
}
bool use_convolution2d(
const Tensor&,
const Tensor&,
const Tensor&,
const IntArrayRef,
const IntArrayRef,
const IntArrayRef,
const int64_t,
const bool) {
return false;
}
Tensor convolution2d(
const Tensor&,
const Tensor&,
const Tensor&,
const IntArrayRef,
const IntArrayRef,
const IntArrayRef,
const int64_t,
const bool) {
TORCH_CHECK(false, internal::kError);
}
bool use_linear(
const Tensor&,
const Tensor&,
const Tensor&) {
return false;
}
Tensor linear(
const Tensor&,
const Tensor&,
const Tensor&) {
TORCH_CHECK(false, internal::kError);
}
} // namespace xnnpack
at::Tensor _conv2d_prepack(
const Tensor&,
const Tensor&,
const IntArrayRef,
const IntArrayRef,
const IntArrayRef,
const int64_t,
const c10::optional<double>,
const c10::optional<double>) {
TORCH_CHECK(false, xnnpack::internal::kError);
}
at::Tensor _conv2d_packed(
const Tensor&,
const Tensor&) {
TORCH_CHECK(false, xnnpack::internal::kError);
}
Tensor _linear_prepack(
const Tensor&,
const Tensor&,
const c10::optional<double>,
const c10::optional<double>) {
TORCH_CHECK(false, xnnpack::internal::kError);
}
Tensor _linear_packed(
const Tensor&,
const Tensor&) {
TORCH_CHECK(false, xnnpack::internal::kError);
}
} // namespace native
} // namespace at
#endif /* USE_XNNPACK */


@ -5,6 +5,7 @@
#include <ATen/Dispatch.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/TensorFactories.h>
#include <ATen/native/utils/Allocator.h>
#include <ATen/quantized/QTensorImpl.h>
#include <ATen/core/Tensor.h>
#include <typeinfo>
@ -478,37 +479,8 @@ QTensorImpl* get_qtensorimpl(const Tensor& self) {
// on a different page out of the process's address space.
// Here we define a custom allocator that allocates the extra storage required to keep
// this behavior safe. This same allocator can be used for FBGEMM as well.
struct QAllocator final : at::Allocator {
public:
virtual ~QAllocator() override = default;
virtual at::DataPtr allocate(size_t nbytes) const override {
Cast memory{c10::alloc_cpu(kGuard + nbytes)};
memory.as_byte_ptr += kGuard;
return {
memory.as_void_ptr,
memory.as_void_ptr,
&deleter,
at::Device(at::DeviceType::CPU)};
}
virtual at::DeleterFnPtr raw_deleter() const override {
return deleter;
}
static void deleter(void * const pointer) {
const Cast memory{pointer};
c10::free_cpu(memory.as_byte_ptr - kGuard);
}
private:
static constexpr uint32_t kGuard = 8u;
union Cast final {
void * const as_void_ptr;
uint8_t * as_byte_ptr;
};
};
using QAllocator = native::GuardingAllocator<8u, 0u>;
#endif


@ -151,7 +151,6 @@ else()
message(FATAL_ERROR "Unrecognized BLAS option: " ${BLAS})
endif()
if (NOT INTERN_BUILD_MOBILE)
set(AT_MKL_ENABLED 0)
set(AT_MKL_MT 0)
@ -180,7 +179,82 @@ if (NOT INTERN_BUILD_MOBILE)
endif()
endif()
# Directory where NNPACK and cpuinfo will download and build all dependencies
# ---[ Dependencies
# NNPACK and family (QNNPACK, PYTORCH_QNNPACK, and XNNPACK) can download and
# compile their dependencies in isolation as part of their build. These dependencies
# are then linked statically with PyTorch. To avoid the possibility of a version
# mismatch between these shared dependencies, explicitly declare our intent to these
# libraries that we are interested in using the exact same source dependencies for all.
if (USE_NNPACK OR USE_QNNPACK OR USE_PYTORCH_QNNPACK OR USE_XNNPACK)
set(DISABLE_NNPACK_AND_FAMILY OFF)
# Sanity checks - Can we actually build NNPACK and family given the configuration provided?
# Disable them and warn the user if not.
if (IOS)
list(LENGTH IOS_ARCH IOS_ARCH_COUNT)
if (IOS_ARCH_COUNT GREATER 1)
message(WARNING
"Multi-architecture (${IOS_ARCH}) builds are not supported in {Q/X}NNPACK. "
"Specify a single architecture in IOS_ARCH and re-configure, or "
"turn this warning off by USE_{Q/X}NNPACK=OFF.")
set(DISABLE_NNPACK_AND_FAMILY ON)
endif()
if (NOT IOS_ARCH MATCHES "^(i386|x86_64|armv7.*|arm64.*)$")
message(WARNING
"Target architecture \"${IOS_ARCH}\" is not supported in {Q/X}NNPACK. "
"Supported architectures are x86, x86-64, ARM, and ARM64. "
"Turn this warning off by USE_{Q/X}NNPACK=OFF.")
set(DISABLE_NNPACK_AND_FAMILY ON)
endif()
else()
if (NOT IOS AND NOT (CMAKE_SYSTEM_NAME MATCHES "^(Android|Linux|Darwin)$"))
message(WARNING
"Target platform \"${CMAKE_SYSTEM_NAME}\" is not supported in {Q/X}NNPACK. "
"Supported platforms are Android, iOS, Linux, and macOS. "
"Turn this warning off by USE_{Q/X}NNPACK=OFF.")
set(DISABLE_NNPACK_AND_FAMILY ON)
endif()
if (NOT IOS AND NOT (CMAKE_SYSTEM_PROCESSOR MATCHES "^(i686|AMD64|x86_64|armv[0-9].*|arm64|aarch64)$"))
message(WARNING
"Target architecture \"${CMAKE_SYSTEM_PROCESSOR}\" is not supported in {Q/X}NNPACK. "
"Supported architectures are x86, x86-64, ARM, and ARM64. "
"Turn this warning off by USE_{Q/X}NNPACK=OFF.")
set(DISABLE_NNPACK_AND_FAMILY ON)
endif()
endif()
if (DISABLE_NNPACK_AND_FAMILY)
set(USE_NNPACK OFF)
set(USE_QNNPACK OFF)
set(USE_PYTORCH_QNNPACK OFF)
set(USE_XNNPACK OFF)
else()
set(CAFFE2_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party")
if (NOT DEFINED CPUINFO_SOURCE_DIR)
set(CPUINFO_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/cpuinfo" CACHE STRING "cpuinfo source directory")
endif()
if (NOT DEFINED FP16_SOURCE_DIR)
set(FP16_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/FP16" CACHE STRING "FP16 source directory")
endif()
if (NOT DEFINED FXDIV_SOURCE_DIR)
set(FXDIV_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/FXdiv" CACHE STRING "FXdiv source directory")
endif()
if (NOT DEFINED PSIMD_SOURCE_DIR)
set(PSIMD_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/psimd" CACHE STRING "PSimd source directory")
endif()
if (NOT DEFINED PTHREADPOOL_SOURCE_DIR)
set(PTHREADPOOL_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/pthreadpool" CACHE STRING "pthreadpool source directory")
endif()
set(CPUINFO_LIBRARY_TYPE "static" CACHE STRING "")
set(CPUINFO_LOG_LEVEL "error" CACHE STRING "")
set(PTHREADPOOL_LIBRARY_TYPE "static" CACHE STRING "")
endif()
endif()
set(CONFU_DEPENDENCIES_SOURCE_DIR ${PROJECT_BINARY_DIR}/confu-srcs
CACHE PATH "Confu-style dependencies source directory")
set(CONFU_DEPENDENCIES_BINARY_DIR ${PROJECT_BINARY_DIR}/confu-deps
@ -199,7 +273,7 @@ endif()
# instead of the default implementation. To avoid confusion, add pthreadpool
# subdirectory explicitly with EXCLUDE_FROM_ALL property prior to QNNPACK/NNPACK
# does so, which will prevent it from installing the default pthreadpool library.
if(INTERN_BUILD_MOBILE AND NOT BUILD_CAFFE2_MOBILE AND (USE_QNNPACK OR USE_NNPACK))
if(INTERN_BUILD_MOBILE AND NOT BUILD_CAFFE2_MOBILE AND (USE_QNNPACK OR USE_NNPACK OR USE_XNNPACK))
if(NOT DEFINED PTHREADPOOL_SOURCE_DIR)
set(CAFFE2_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party")
set(PTHREADPOOL_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/pthreadpool" CACHE STRING "pthreadpool source directory")
@ -217,69 +291,17 @@ endif()
# ---[ QNNPACK
if(USE_QNNPACK)
if (IOS)
list(LENGTH IOS_ARCH IOS_ARCH_COUNT)
if (IOS_ARCH_COUNT GREATER 1)
message(WARNING
"Multi-architecture (${IOS_ARCH}) builds are not supported in QNNPACK. "
"Specify a single architecture in IOS_ARCH and re-configure, or "
"turn this warning off by USE_QNNPACK=OFF.")
set(USE_QNNPACK OFF)
endif()
if (NOT IOS_ARCH MATCHES "^(i386|x86_64|armv7.*|arm64.*)$")
message(WARNING
"Target architecture \"${IOS_ARCH}\" is not supported in QNNPACK. "
"Supported architectures are x86, x86-64, ARM, and ARM64. "
"Turn this warning off by USE_QNNPACK=OFF.")
set(USE_QNNPACK OFF)
endif()
else()
if (NOT IOS AND NOT (CMAKE_SYSTEM_NAME MATCHES "^(Android|Linux|Darwin)$"))
message(WARNING
"Target platform \"${CMAKE_SYSTEM_NAME}\" is not supported in QNNPACK. "
"Supported platforms are Android, iOS, Linux, and macOS. "
"Turn this warning off by USE_QNNPACK=OFF.")
set(USE_QNNPACK OFF)
endif()
if (NOT IOS AND NOT (CMAKE_SYSTEM_PROCESSOR MATCHES "^(i686|AMD64|x86_64|armv[0-9].*|arm64|aarch64)$"))
message(WARNING
"Target architecture \"${CMAKE_SYSTEM_PROCESSOR}\" is not supported in QNNPACK. "
"Supported architectures are x86, x86-64, ARM, and ARM64. "
"Turn this warning off by USE_QNNPACK=OFF.")
set(USE_QNNPACK OFF)
endif()
endif()
if (USE_QNNPACK)
set(CAFFE2_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party")
# Directories for QNNPACK dependencies submoduled in Caffe2
if (NOT DEFINED CPUINFO_SOURCE_DIR)
set(CPUINFO_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/cpuinfo" CACHE STRING "cpuinfo source directory")
endif()
if (NOT DEFINED QNNPACK_SOURCE_DIR)
set(QNNPACK_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/QNNPACK" CACHE STRING "QNNPACK source directory")
endif()
if (NOT DEFINED FP16_SOURCE_DIR)
set(FP16_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/FP16" CACHE STRING "FP16 source directory")
endif()
if (NOT DEFINED FXDIV_SOURCE_DIR)
set(FXDIV_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/FXdiv" CACHE STRING "FXdiv source directory")
endif()
if (NOT DEFINED PSIMD_SOURCE_DIR)
set(PSIMD_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/psimd" CACHE STRING "PSimd source directory")
endif()
if (NOT DEFINED PTHREADPOOL_SOURCE_DIR)
set(PTHREADPOOL_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/pthreadpool" CACHE STRING "pthreadpool source directory")
endif()
if(NOT TARGET qnnpack)
set(QNNPACK_BUILD_TESTS OFF CACHE BOOL "")
set(QNNPACK_BUILD_BENCHMARKS OFF CACHE BOOL "")
set(QNNPACK_CUSTOM_THREADPOOL ON CACHE BOOL "")
set(QNNPACK_LIBRARY_TYPE "static" CACHE STRING "")
set(PTHREADPOOL_LIBRARY_TYPE "static" CACHE STRING "")
set(CPUINFO_LIBRARY_TYPE "static" CACHE STRING "")
set(CPUINFO_LOG_LEVEL "error" CACHE STRING "")
add_subdirectory(
"${QNNPACK_SOURCE_DIR}"
"${CONFU_DEPENDENCIES_BINARY_DIR}/QNNPACK")
@ -291,7 +313,6 @@ if(USE_QNNPACK)
endif()
list(APPEND Caffe2_DEPENDENCY_LIBS qnnpack)
endif()
endif()
# ---[ Caffe2 Int8 operators (enabled by USE_QNNPACK) depend on gemmlowp and neon2sse headers
@ -303,39 +324,6 @@ endif()
# ---[ PYTORCH_QNNPACK
if(USE_PYTORCH_QNNPACK)
if (IOS)
list(LENGTH IOS_ARCH IOS_ARCH_COUNT)
if (IOS_ARCH_COUNT GREATER 1)
message(WARNING
"Multi-architecture (${IOS_ARCH}) builds are not supported in QNNPACK. "
"Specify a single architecture in IOS_ARCH and re-configure, or "
"turn this warning off by USE_PYTORCH_QNNPACK=OFF.")
set(USE_PYTORCH_QNNPACK OFF)
endif()
if (NOT IOS_ARCH MATCHES "^(i386|x86_64|armv7.*|arm64.*)$")
message(WARNING
"Target architecture \"${IOS_ARCH}\" is not supported in QNNPACK. "
"Supported architectures are x86, x86-64, ARM, and ARM64. "
"Turn this warning off by USE_PYTORCH_QNNPACK=OFF.")
set(USE_PYTORCH_QNNPACK OFF)
endif()
else()
if (NOT IOS AND NOT (CMAKE_SYSTEM_NAME MATCHES "^(Android|Linux|Darwin)$"))
message(WARNING
"Target platform \"${CMAKE_SYSTEM_NAME}\" is not supported in QNNPACK. "
"Supported platforms are Android, iOS, Linux, and macOS. "
"Turn this warning off by USE_PYTORCH_QNNPACK=OFF.")
set(USE_PYTORCH_QNNPACK OFF)
endif()
if (NOT IOS AND NOT (CMAKE_SYSTEM_PROCESSOR MATCHES "^(i686|AMD64|x86_64|armv[0-9].*|arm64|aarch64)$"))
message(WARNING
"Target architecture \"${CMAKE_SYSTEM_PROCESSOR}\" is not supported in QNNPACK. "
"Supported architectures are x86, x86-64, ARM, and ARM64. "
"Turn this warning off by USE_PYTORCH_QNNPACK=OFF.")
set(USE_PYTORCH_QNNPACK OFF)
endif()
endif()
if (USE_PYTORCH_QNNPACK)
if (NOT DEFINED PYTORCH_QNNPACK_SOURCE_DIR)
set(PYTORCH_QNNPACK_SOURCE_DIR "${PROJECT_SOURCE_DIR}/aten/src/ATen/native/quantized/cpu/qnnpack" CACHE STRING "QNNPACK source directory")
endif()
@ -345,9 +333,6 @@ if(USE_PYTORCH_QNNPACK)
set(PYTORCH_QNNPACK_BUILD_BENCHMARKS OFF CACHE BOOL "")
set(PYTORCH_QNNPACK_CUSTOM_THREADPOOL ON CACHE BOOL "")
set(PYTORCH_QNNPACK_LIBRARY_TYPE "static" CACHE STRING "")
set(PTHREADPOOL_LIBRARY_TYPE "static" CACHE STRING "")
set(CPUINFO_LIBRARY_TYPE "static" CACHE STRING "")
set(CPUINFO_LOG_LEVEL "error" CACHE STRING "")
add_subdirectory(
"${PYTORCH_QNNPACK_SOURCE_DIR}"
"${CONFU_DEPENDENCIES_BINARY_DIR}/pytorch_qnnpack")
@ -359,7 +344,6 @@ if(USE_PYTORCH_QNNPACK)
endif()
list(APPEND Caffe2_DEPENDENCY_LIBS pytorch_qnnpack)
endif()
endif()
# ---[ NNPACK
@ -379,6 +363,33 @@ if(USE_NNPACK)
endif()
endif()
# ---[ XNNPACK
if(USE_XNNPACK)
if (NOT DEFINED XNNPACK_SOURCE_DIR)
set(XNNPACK_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/XNNPACK" CACHE STRING "XNNPACK source directory")
endif()
if (NOT DEFINED XNNPACK_INCLUDE_DIR)
set(XNNPACK_INCLUDE_DIR "${XNNPACK_SOURCE_DIR}/include" CACHE STRING "XNNPACK include directory")
endif()
if(NOT TARGET XNNPACK)
set(XNNPACK_CUSTOM_THREADPOOL ON CACHE BOOL "")
set(XNNPACK_LIBRARY_TYPE "static" CACHE STRING "")
set(XNNPACK_BUILD_BENCHMARKS OFF CACHE BOOL "")
set(XNNPACK_BUILD_TESTS OFF CACHE BOOL "")
add_subdirectory(
"${XNNPACK_SOURCE_DIR}"
"${CONFU_DEPENDENCIES_BINARY_DIR}/XNNPACK")
set_property(TARGET XNNPACK PROPERTY POSITION_INDEPENDENT_CODE ON)
endif()
include_directories(SYSTEM ${XNNPACK_INCLUDE_DIR})
list(APPEND Caffe2_DEPENDENCY_LIBS XNNPACK)
endif()
# ---[ Caffe2 uses cpuinfo library in the thread pool
if (NOT TARGET cpuinfo)
if (NOT DEFINED CPUINFO_SOURCE_DIR)


@ -64,6 +64,11 @@ if (NOT @BUILD_SHARED_LIBS@)
list(APPEND TORCH_LIBRARIES ${PYTORCH_QNNPACK_LIBRARY})
endif()
if (@USE_XNNPACK@)
find_library(XNNPACK_LIBRARY XNNPACK PATHS "${TORCH_INSTALL_PREFIX}/lib")
list(APPEND TORCH_LIBRARIES ${XNNPACK_LIBRARY})
endif()
if (@INTERN_USE_EIGEN_BLAS@)
find_library(EIGEN_BLAS_LIBRARY eigen_blas PATHS "${TORCH_INSTALL_PREFIX}/lib")
list(APPEND TORCH_LIBRARIES ${EIGEN_BLAS_LIBRARY})


@ -63,7 +63,7 @@ targets.each do |target|
target.resources_build_phase.add_file_reference(config_file_ref, true)
end
puts "Linking static libraries..."
libs = ['libc10.a', 'libclog.a', 'libnnpack.a', 'libeigen_blas.a', 'libcpuinfo.a', 'libpytorch_qnnpack.a', 'libtorch_cpu.a', 'libtorch.a']
libs = ['libc10.a', 'libclog.a', 'libnnpack.a', 'libXNNPACK.a', 'libeigen_blas.a', 'libcpuinfo.a', 'libpytorch_qnnpack.a', 'libtorch_cpu.a', 'libtorch.a']
targets.each do |target|
target.frameworks_build_phases.clear
for lib in libs do


@ -51,7 +51,7 @@ end
# link static libraries
target.frameworks_build_phases.clear
libs = ['libc10.a', 'libclog.a', 'libnnpack.a', 'libeigen_blas.a', 'libcpuinfo.a', 'libpytorch_qnnpack.a', 'libtorch_cpu.a', 'libtorch.a']
libs = ['libc10.a', 'libclog.a', 'libnnpack.a', 'libXNNPACK.a', 'libeigen_blas.a', 'libcpuinfo.a', 'libpytorch_qnnpack.a', 'libtorch_cpu.a', 'libtorch.a']
for lib in libs do
path = "#{install_path}/lib/#{lib}"
if File.exist?(path)

Submodule third_party/XNNPACK (vendored) added at fa611cc5c2