mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Mobile Backend: NHWC memory layout + XNNPACK integration. (#33722)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33722 In order to improve CPU performance on floating-point models on mobile, this PR introduces a new CPU backend for mobile that implements the most common mobile operators with NHWC memory layout support through integration with XNNPACK. XNNPACK itself, and this codepath, are currently only included in the build, but the actual integration is gated with USE_XNNPACK preprocessor guards. This preprocessor symbol is intentionally not passed on to the compiler, so as to enable this rollout in multiple stages in follow up PRs. This changeset will build XNNPACK as part of the build if the identically named USE_XNNPACK CMAKE variable, defaulted to ON, is enabled, but will not actually expose or enable this code path in any other way. Furthermore, it is worth pointing out that in order to efficiently map models to these operators, some front-end method of exposing this backend to the user is needed. The less efficient implementation would be to hook these operators into their corresponding native implementations, granted that a series of XNNPACK-specific conditions are met, much like how NNPACK is integrated with PyTorch today for instance. Having said that, while the above implementation is still expected to outperform NNPACK based on the benchmarks I ran, the above integration would be leave a considerable gap between the performance achieved and the maximum performance potential XNNPACK enables, as it does not provide a way to compute and factor out one-time operations out of the inner most forward() loop. The more optimal solution, and one we will decide on soon, would involve either providing a JIT pass that maps nn operators onto these newly introduced operators, while allowing one-time calculations to be factored out, much like quantized mobile models. Alternatively, new eager-mode modules can also be introduced that would directly call into these implementations either through c10 or some other mechanism, also allowing for decoupling of op creation from op execution. This PR does not include any of the front end changes mentioned above. Neither does it include the mobile threadpool unification present in the original https://github.com/pytorch/pytorch/issues/30644. Furthermore, this codepath seems to be faster than NNPACK in a good number of use cases, which can potentially allow us to remove NNPACK from aten to make the codebase a little simpler, granted that there is widespread support for such a move. Regardless, these changes will be introduced gradually and in a more controlled way in subsequent PRs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/32509 Test Plan: Build: CI Functionality: Not exposed Reviewed By: dreiss Differential Revision: D20069796 Pulled By: AshkanAliabadi fbshipit-source-id: d46c1c91d4bea91979ea5bd46971ced5417d309c
This commit is contained in:
committed by
Facebook Github Bot
parent
2a4aad7466
commit
6aecfd1e80
@ -14,7 +14,7 @@ mkdir -p ${ZIP_DIR}/src
|
||||
cp -R ${ARTIFACTS_DIR}/arm64/include ${ZIP_DIR}/install/
|
||||
# build a FAT bianry
|
||||
cd ${ZIP_DIR}/install/lib
|
||||
target_libs=(libc10.a libclog.a libcpuinfo.a libeigen_blas.a libpytorch_qnnpack.a libtorch_cpu.a libtorch.a)
|
||||
target_libs=(libc10.a libclog.a libcpuinfo.a libeigen_blas.a libpytorch_qnnpack.a libtorch_cpu.a libtorch.a libXNNPACK.a)
|
||||
for lib in ${target_libs[*]}
|
||||
do
|
||||
if [ -f "${ARTIFACTS_DIR}/x86_64/lib/${lib}" ] && [ -f "${ARTIFACTS_DIR}/arm64/lib/${lib}" ]; then
|
||||
|
10
.gitmodules
vendored
10
.gitmodules
vendored
@ -111,10 +111,14 @@
|
||||
path = third_party/foxi
|
||||
url = https://github.com/houseroad/foxi.git
|
||||
[submodule "third_party/tbb"]
|
||||
path = third_party/tbb
|
||||
url = https://github.com/01org/tbb
|
||||
branch = tbb_2018
|
||||
path = third_party/tbb
|
||||
url = https://github.com/01org/tbb
|
||||
branch = tbb_2018
|
||||
[submodule "android/libs/fbjni"]
|
||||
ignore = dirty
|
||||
path = android/libs/fbjni
|
||||
url = https://github.com/facebookincubator/fbjni.git
|
||||
[submodule "third_party/XNNPACK"]
|
||||
path = third_party/XNNPACK
|
||||
url = https://github.com/AshkanAliabadi/XNNPACK.git
|
||||
branch = xnnpack_pytorch_merge_temp
|
||||
|
@ -185,6 +185,7 @@ option(USE_SNPE "Use Qualcomm's SNPE library" OFF)
|
||||
option(USE_SYSTEM_EIGEN_INSTALL
|
||||
"Use system Eigen instead of the one under third_party" OFF)
|
||||
option(USE_TENSORRT "Using Nvidia TensorRT library" OFF)
|
||||
option(USE_XNNPACK "Use XNNPACK" ON)
|
||||
option(USE_ZMQ "Use ZMQ" OFF)
|
||||
option(USE_ZSTD "Use ZSTD" OFF)
|
||||
cmake_dependent_option(
|
||||
@ -416,6 +417,10 @@ if(USE_PYTORCH_QNNPACK)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_PYTORCH_QNNPACK")
|
||||
endif()
|
||||
|
||||
if(USE_XNNPACK)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_XNNPACK")
|
||||
endif()
|
||||
|
||||
# ---[ Whitelist file if whitelist is specified
|
||||
include(cmake/Whitelist.cmake)
|
||||
|
||||
|
@ -83,6 +83,7 @@ if (ANDROID_ABI)
|
||||
import_static_lib(libtorch_cpu)
|
||||
import_static_lib(libc10)
|
||||
import_static_lib(libnnpack)
|
||||
import_static_lib(libXNNPACK)
|
||||
import_static_lib(libpytorch_qnnpack)
|
||||
import_static_lib(libeigen_blas)
|
||||
import_static_lib(libcpuinfo)
|
||||
@ -98,6 +99,7 @@ if (ANDROID_ABI)
|
||||
-Wl,--no-whole-archive
|
||||
libc10
|
||||
libnnpack
|
||||
libXNNPACK
|
||||
libpytorch_qnnpack
|
||||
libeigen_blas
|
||||
libcpuinfo
|
||||
@ -113,6 +115,7 @@ else()
|
||||
torch_cpu
|
||||
c10
|
||||
nnpack
|
||||
XNNPACK
|
||||
pytorch_qnnpack
|
||||
cpuinfo
|
||||
clog
|
||||
|
@ -84,8 +84,11 @@ FILE(GLOB native_sparse_hip_cpp "native/sparse/hip/*.cpp")
|
||||
FILE(GLOB native_quantized_hip_hip "native/quantized/hip/*.hip")
|
||||
FILE(GLOB native_quantized_hip_cpp "native/quantized/hip/*.cpp")
|
||||
|
||||
# XNNPACK
|
||||
FILE(GLOB native_xnnpack "native/xnnpack/*.cpp")
|
||||
|
||||
add_subdirectory(quantized)
|
||||
set(all_cpu_cpp ${base_cpp} ${ATen_CORE_SRCS} ${native_cpp} ${native_sparse_cpp} ${native_quantized_cpp} ${native_mkl_cpp} ${native_mkldnn_cpp} ${generated_cpp} ${core_generated_cpp} ${ATen_CPU_SRCS} ${ATen_QUANTIZED_SRCS} ${cpu_kernel_cpp})
|
||||
set(all_cpu_cpp ${base_cpp} ${ATen_CORE_SRCS} ${native_cpp} ${native_sparse_cpp} ${native_quantized_cpp} ${native_mkl_cpp} ${native_mkldnn_cpp} ${native_xnnpack} ${generated_cpp} ${core_generated_cpp} ${ATen_CPU_SRCS} ${ATen_QUANTIZED_SRCS} ${cpu_kernel_cpp})
|
||||
if(AT_MKL_ENABLED)
|
||||
set(all_cpu_cpp ${all_cpu_cpp} ${mkl_cpp})
|
||||
endif()
|
||||
|
@ -775,6 +775,10 @@
|
||||
|
||||
- func: conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor
|
||||
|
||||
- func: _conv2d_prepack(Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1, float? output_min=None, float? output_max=None) -> Tensor
|
||||
|
||||
- func: _conv2d_packed(Tensor packed_weight, Tensor input) -> Tensor
|
||||
|
||||
- func: conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor
|
||||
|
||||
- func: conv_tbc(Tensor self, Tensor weight, Tensor bias, int pad=0) -> Tensor
|
||||
@ -1575,6 +1579,10 @@
|
||||
- func: linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
|
||||
python_module: nn
|
||||
|
||||
- func: _linear_prepack(Tensor weight, Tensor? bias=None, float? output_min=None, float? output_max=None) -> Tensor
|
||||
|
||||
- func: _linear_packed(Tensor packed_weight, Tensor input) -> Tensor
|
||||
|
||||
- func: mkldnn_linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
|
||||
python_module: nn
|
||||
dispatch:
|
||||
|
55
aten/src/ATen/native/utils/Allocator.h
Normal file
55
aten/src/ATen/native/utils/Allocator.h
Normal file
@ -0,0 +1,55 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/CPUAllocator.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
// QNNPACK AND XNNPACK may out-of-bound access the input and / or output tensors.
|
||||
// This behavior will trigger ASAN, and may result in a segfault if the accessed
|
||||
// memory just so happens to fall on a page the current process has no read access
|
||||
// to. Here we define a custom allocator that allocates the extra storage required
|
||||
// to keep this behavior safe.
|
||||
//
|
||||
// PreGuardBytes: Number of guard bytes to allocate before the allocation.
|
||||
// PostGuardBytes: Number of guard bytes to allocate after the allocation.
|
||||
|
||||
template <uint32_t PreGuardBytes, uint32_t PostGuardBytes>
|
||||
class GuardingAllocator final : public at::Allocator {
|
||||
public:
|
||||
GuardingAllocator() = default;
|
||||
virtual ~GuardingAllocator() override = default;
|
||||
|
||||
static void deleter(void* pointer) {
|
||||
const Cast memory{pointer};
|
||||
c10::free_cpu(memory.as_byte_ptr - kPreGuardBytes);
|
||||
}
|
||||
|
||||
virtual DataPtr allocate(size_t nbytes) const override {
|
||||
Cast memory{c10::alloc_cpu(kPreGuardBytes + nbytes + kPostGuardBytes)};
|
||||
memory.as_byte_ptr += kPreGuardBytes;
|
||||
|
||||
return {
|
||||
memory.as_void_ptr,
|
||||
memory.as_void_ptr,
|
||||
&deleter,
|
||||
at::Device(DeviceType::CPU),
|
||||
};
|
||||
}
|
||||
|
||||
virtual DeleterFnPtr raw_deleter() const override {
|
||||
return deleter;
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr uint32_t kPreGuardBytes = PreGuardBytes;
|
||||
static constexpr uint32_t kPostGuardBytes = PostGuardBytes;
|
||||
|
||||
union Cast final {
|
||||
void * const as_void_ptr;
|
||||
uint8_t * as_byte_ptr;
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
82
aten/src/ATen/native/xnnpack/Common.h
Normal file
82
aten/src/ATen/native/xnnpack/Common.h
Normal file
@ -0,0 +1,82 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#ifdef USE_XNNPACK
|
||||
|
||||
#include <xnnpack.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace xnnpack {
|
||||
namespace internal {
|
||||
|
||||
struct Layout final {
|
||||
// 4D Activation Maps
|
||||
struct Activation4D final {
|
||||
static constexpr size_t batch = 0u;
|
||||
static constexpr size_t channels = 1u;
|
||||
static constexpr size_t height = 2u;
|
||||
static constexpr size_t width = 3u;
|
||||
};
|
||||
|
||||
// ND Activation Maps
|
||||
struct ActivationND final {
|
||||
// Some operators may not be limited to 4 dimensional tensors. In that scenario,
|
||||
// XNNPACK denotes that operator with an _nc suffix and expects all dimensions,
|
||||
// except channels, to be flattened into one argument: batch_size.
|
||||
static int64_t batch(const IntArrayRef tensor) {
|
||||
if (C10_UNLIKELY(tensor.empty())) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Handle the case where batch size is zero.
|
||||
int64_t batch = std::max<int64_t>(1, tensor[0]);
|
||||
|
||||
for (size_t index = 1u; index < (tensor.size() - 1u); ++index) {
|
||||
batch *= tensor[index];
|
||||
}
|
||||
|
||||
return batch;
|
||||
};
|
||||
|
||||
static int64_t channel(const IntArrayRef tensor) {
|
||||
if (C10_UNLIKELY(tensor.empty())) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
return tensor.back();
|
||||
};
|
||||
};
|
||||
|
||||
// Convolution Filters
|
||||
struct Filter final {
|
||||
static constexpr size_t output = 0u;
|
||||
static constexpr size_t input = 1u;
|
||||
static constexpr size_t height = 2u;
|
||||
static constexpr size_t width = 3u;
|
||||
};
|
||||
|
||||
// Parameters (Pooling Kernels, Dilation, Padding, Stride, etc.)
|
||||
struct Parameter final {
|
||||
static constexpr size_t height = 0u;
|
||||
static constexpr size_t width = 1u;
|
||||
};
|
||||
};
|
||||
|
||||
struct Deleter final {
|
||||
void operator()(const xnn_operator_t op) const {
|
||||
xnn_delete_operator(op);
|
||||
}
|
||||
};
|
||||
|
||||
using Operator = std::unique_ptr<xnn_operator, Deleter>;
|
||||
|
||||
bool available();
|
||||
|
||||
} // namespace internal
|
||||
} // namespace xnnpack
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
||||
#endif /* USE_XNNPACK */
|
316
aten/src/ATen/native/xnnpack/Convolution.cpp
Normal file
316
aten/src/ATen/native/xnnpack/Convolution.cpp
Normal file
@ -0,0 +1,316 @@
|
||||
#ifdef USE_XNNPACK
|
||||
|
||||
#include <ATen/cpp_custom_type_hack.h>
|
||||
#include <ATen/native/ConvUtils.h>
|
||||
#include <ATen/native/utils/ParamUtils.h>
|
||||
#include <ATen/native/xnnpack/Common.h>
|
||||
#include <ATen/native/xnnpack/Factory.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace xnnpack {
|
||||
namespace internal {
|
||||
namespace convolution2d {
|
||||
|
||||
struct Context final {
|
||||
Operator convolution_op;
|
||||
|
||||
std::vector<int64_t> weight_size;
|
||||
std::vector<int64_t> padding;
|
||||
std::vector<int64_t> stride;
|
||||
std::vector<int64_t> dilation;
|
||||
|
||||
static constexpr float kMin = -std::numeric_limits<float>::infinity();
|
||||
static constexpr float kMax = std::numeric_limits<float>::infinity();
|
||||
};
|
||||
|
||||
namespace {
|
||||
|
||||
// Supports NHWC and NCHW FP32 convolutions with any valid
|
||||
// - kernel size
|
||||
// - padding
|
||||
// - stride
|
||||
// - dilation
|
||||
// - grouping
|
||||
|
||||
// TODO: Decouple and improve error handling and messages.
|
||||
bool available(
|
||||
const Tensor& weight,
|
||||
const c10::optional<Tensor>& bias,
|
||||
const IntArrayRef padding,
|
||||
const IntArrayRef stride,
|
||||
const IntArrayRef dilation,
|
||||
const int64_t groups,
|
||||
const float output_min,
|
||||
const float output_max) {
|
||||
// XNNPACK
|
||||
return xnnpack::internal::available() &&
|
||||
// Weight
|
||||
(4 == weight.ndimension()) &&
|
||||
(weight.size(Layout::Filter::height) > 0) &&
|
||||
(weight.size(Layout::Filter::width) > 0) &&
|
||||
(c10::DeviceType::CPU == weight.device().type()) &&
|
||||
(kFloat == weight.scalar_type()) &&
|
||||
// Bias
|
||||
((bias && bias->defined()) ? ((1 == bias->ndimension()) &&
|
||||
(c10::DeviceType::CPU == bias->device().type()) &&
|
||||
(kFloat == bias->scalar_type()) &&
|
||||
(weight.size(Layout::Filter::output)) == bias->size(0))
|
||||
: true) &&
|
||||
// Padding
|
||||
(padding[Layout::Parameter::height] >= 0) &&
|
||||
(padding[Layout::Parameter::width] >= 0) &&
|
||||
// Stride
|
||||
(stride[Layout::Parameter::height] > 0) &&
|
||||
(stride[Layout::Parameter::width] > 0) &&
|
||||
// Dilation
|
||||
(dilation[Layout::Parameter::height] > 0) &&
|
||||
(dilation[Layout::Parameter::width] > 0) &&
|
||||
// Groups
|
||||
(groups > 0) &&
|
||||
// Input
|
||||
(weight.size(Layout::Filter::input) > 0) &&
|
||||
// Output
|
||||
(weight.size(Layout::Filter::output) > 0) &&
|
||||
// Output - Groups
|
||||
((weight.size(Layout::Filter::output) % groups) == 0) &&
|
||||
// Output Min / Max
|
||||
(output_max > output_min) &&
|
||||
true;
|
||||
}
|
||||
|
||||
Context create(
|
||||
const Tensor& weight,
|
||||
const c10::optional<Tensor>& bias,
|
||||
const IntArrayRef padding_,
|
||||
const IntArrayRef stride_,
|
||||
const IntArrayRef dilation_,
|
||||
const int64_t groups,
|
||||
const float output_min,
|
||||
const float output_max) {
|
||||
const auto padding = expand_param_if_needed(padding_, "padding", 2);
|
||||
const auto stride = expand_param_if_needed(stride_, "stride", 2);
|
||||
const auto dilation = expand_param_if_needed(dilation_, "dilation", 2);
|
||||
const Tensor weight_nhwc = weight.contiguous(MemoryFormat::ChannelsLast);
|
||||
|
||||
TORCH_CHECK(
|
||||
available(
|
||||
weight_nhwc,
|
||||
bias,
|
||||
padding,
|
||||
stride,
|
||||
dilation,
|
||||
groups,
|
||||
output_min,
|
||||
output_max),
|
||||
"XNNPACK Convolution not available! "
|
||||
"Reason: The provided (weight, bias, padding, stride, dilation, groups, output_min, output_max) "
|
||||
"parameters are either invalid individually or their combination is not supported by XNNPACK.");
|
||||
|
||||
xnn_operator_t convolution_op{};
|
||||
|
||||
const xnn_status create_status = xnn_create_convolution2d_nhwc_f32(
|
||||
padding[Layout::Parameter::height], // input_padding_top
|
||||
padding[Layout::Parameter::width], // input_padding_right
|
||||
padding[Layout::Parameter::height], // input_padding_bottom
|
||||
padding[Layout::Parameter::width], // input_padding_left
|
||||
weight_nhwc.size(Layout::Filter::height), // kernel_height
|
||||
weight_nhwc.size(Layout::Filter::width), // kernel_width
|
||||
stride[Layout::Parameter::height], // subsampling_height
|
||||
stride[Layout::Parameter::width], // subsampling_width
|
||||
dilation[Layout::Parameter::height], // dilation_height
|
||||
dilation[Layout::Parameter::width], // dilation_width
|
||||
groups, // groups
|
||||
weight_nhwc.size(Layout::Filter::input), // group_input_channels
|
||||
weight_nhwc.size(Layout::Filter::output) / groups, // group_output_channels
|
||||
weight_nhwc.size(Layout::Filter::input) * groups, // input_pixel_stride
|
||||
weight_nhwc.size(Layout::Filter::output), // output_pixel_stride
|
||||
weight_nhwc.data_ptr<float>(), // kernel
|
||||
(bias && bias->defined()) ? bias->data_ptr<float>() : nullptr, // bias
|
||||
output_min, // output_min
|
||||
output_max, // output_max
|
||||
0u, // flags
|
||||
&convolution_op); // operator
|
||||
|
||||
TORCH_CHECK(
|
||||
xnn_status_success == create_status,
|
||||
"xnn_create_convolution2d_nhwc_f32 failed!");
|
||||
|
||||
return Context{
|
||||
Operator(convolution_op),
|
||||
weight_nhwc.sizes().vec(),
|
||||
padding,
|
||||
stride,
|
||||
dilation,
|
||||
};
|
||||
}
|
||||
|
||||
// TODO: Decouple and improve error handling and messages.
|
||||
bool usable(const Tensor& input) {
|
||||
// Input
|
||||
return (4 == input.ndimension()) &&
|
||||
(c10::DeviceType::CPU == input.device().type()) &&
|
||||
(kFloat == input.scalar_type()) &&
|
||||
(input.size(Layout::Activation4D::batch) > 0) &&
|
||||
(input.size(Layout::Activation4D::channels) > 0) &&
|
||||
(input.size(Layout::Activation4D::height) > 0) &&
|
||||
(input.size(Layout::Activation4D::width) > 0) &&
|
||||
true;
|
||||
}
|
||||
|
||||
Tensor run(
|
||||
const Context& context,
|
||||
const Tensor& input) {
|
||||
using namespace internal;
|
||||
|
||||
const Tensor input_nhwc = input.contiguous(MemoryFormat::ChannelsLast);
|
||||
|
||||
TORCH_CHECK(
|
||||
usable(input_nhwc),
|
||||
"XNNPACK Convolution not usable! "
|
||||
"Reason: The provided input tensor is either invalid or unsupported by XNNPACK.");
|
||||
|
||||
Tensor output = empty_with_tail_padding(
|
||||
conv_output_size(
|
||||
input_nhwc.sizes(),
|
||||
context.weight_size,
|
||||
context.padding,
|
||||
context.stride,
|
||||
context.dilation),
|
||||
input_nhwc.options().dtype(),
|
||||
MemoryFormat::ChannelsLast);
|
||||
|
||||
const xnn_status setup_status = xnn_setup_convolution2d_nhwc_f32(
|
||||
context.convolution_op.get(), // operator
|
||||
input_nhwc.size(Layout::Activation4D::batch), // batch_size
|
||||
input_nhwc.size(Layout::Activation4D::height), // input_height
|
||||
input_nhwc.size(Layout::Activation4D::width), // input_width
|
||||
input_nhwc.data_ptr<float>(), // input
|
||||
output.data_ptr<float>(), // output
|
||||
nullptr); // threadpool
|
||||
|
||||
TORCH_CHECK(
|
||||
xnn_status_success == setup_status,
|
||||
"xnn_setup_convolution2d_nhwc_f32 failed!");
|
||||
|
||||
const xnn_status run_status = xnn_run_operator(
|
||||
context.convolution_op.get(), // operator
|
||||
nullptr); // threadpool
|
||||
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
xnn_status_success == run_status,
|
||||
"xnn_run_operator failed!");
|
||||
|
||||
return output.contiguous(input.suggest_memory_format());
|
||||
}
|
||||
|
||||
Tensor create_and_run(
|
||||
const Tensor& input,
|
||||
const Tensor& weight,
|
||||
const Tensor& bias,
|
||||
const IntArrayRef padding,
|
||||
const IntArrayRef stride,
|
||||
const IntArrayRef dilation,
|
||||
const int64_t groups,
|
||||
const float output_min,
|
||||
const float output_max) {
|
||||
return run(
|
||||
create(
|
||||
weight,
|
||||
bias,
|
||||
padding,
|
||||
stride,
|
||||
dilation,
|
||||
groups,
|
||||
output_min,
|
||||
output_max),
|
||||
input);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace convolution2d
|
||||
} // namespace internal
|
||||
|
||||
bool use_convolution2d(
|
||||
const Tensor& input,
|
||||
const Tensor& weight,
|
||||
const Tensor& bias,
|
||||
const IntArrayRef padding,
|
||||
const IntArrayRef stride,
|
||||
const IntArrayRef dilation,
|
||||
const int64_t groups) {
|
||||
return internal::convolution2d::available(
|
||||
weight,
|
||||
bias,
|
||||
padding,
|
||||
stride,
|
||||
dilation,
|
||||
groups,
|
||||
internal::convolution2d::Context::kMin,
|
||||
internal::convolution2d::Context::kMax) &&
|
||||
internal::convolution2d::usable(input);
|
||||
}
|
||||
|
||||
Tensor convolution2d(
|
||||
const Tensor& input,
|
||||
const Tensor& weight,
|
||||
const Tensor& bias,
|
||||
const IntArrayRef padding,
|
||||
const IntArrayRef stride,
|
||||
const IntArrayRef dilation,
|
||||
const int64_t groups) {
|
||||
return internal::convolution2d::create_and_run(
|
||||
input,
|
||||
weight,
|
||||
bias,
|
||||
padding,
|
||||
stride,
|
||||
dilation,
|
||||
groups,
|
||||
internal::convolution2d::Context::kMin,
|
||||
internal::convolution2d::Context::kMax);
|
||||
}
|
||||
|
||||
} // namespace xnnpack
|
||||
|
||||
at::Tensor _conv2d_prepack(
|
||||
const Tensor& weight,
|
||||
const Tensor& bias,
|
||||
const IntArrayRef stride,
|
||||
const IntArrayRef padding,
|
||||
const IntArrayRef dilation,
|
||||
const int64_t groups,
|
||||
const c10::optional<double> output_min,
|
||||
const c10::optional<double> output_max) {
|
||||
return cpp_custom_type_hack::create(
|
||||
std::make_unique<xnnpack::internal::convolution2d::Context>(
|
||||
xnnpack::internal::convolution2d::create(
|
||||
weight,
|
||||
bias,
|
||||
padding.vec(),
|
||||
stride.vec(),
|
||||
dilation.vec(),
|
||||
groups,
|
||||
output_min ? *output_min : xnnpack::internal::convolution2d::Context::kMin,
|
||||
output_max ? *output_max : xnnpack::internal::convolution2d::Context::kMax)),
|
||||
weight.options());
|
||||
}
|
||||
|
||||
at::Tensor _conv2d_packed(
|
||||
const Tensor& packed_weight,
|
||||
const Tensor& input) {
|
||||
return xnnpack::internal::convolution2d::run(
|
||||
cpp_custom_type_hack::cast<xnnpack::internal::convolution2d::Context>(packed_weight),
|
||||
input);
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
CAFFE_KNOWN_TYPE(at::native::xnnpack::internal::convolution2d::Context);
|
||||
|
||||
} // namespace caffe2
|
||||
|
||||
#endif /* USE_XNNPACK */
|
38
aten/src/ATen/native/xnnpack/Factory.cpp
Normal file
38
aten/src/ATen/native/xnnpack/Factory.cpp
Normal file
@ -0,0 +1,38 @@
|
||||
#ifdef USE_XNNPACK
|
||||
|
||||
#include <ATen/native/xnnpack/Factory.h>
|
||||
#include <ATen/native/utils/Allocator.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace xnnpack {
|
||||
namespace internal {
|
||||
|
||||
Tensor empty_with_tail_padding(
|
||||
const IntArrayRef size,
|
||||
const caffe2::TypeMeta dtype,
|
||||
const c10::MemoryFormat memory_format) {
|
||||
static GuardingAllocator<0u, XNN_EXTRA_BYTES> allocator;
|
||||
|
||||
const int64_t nelements = prod_intlist(size);
|
||||
|
||||
Tensor tensor(
|
||||
c10::make_intrusive<c10::TensorImpl>(
|
||||
c10::Storage{
|
||||
dtype,
|
||||
nelements,
|
||||
allocator.allocate(nelements * dtype.itemsize()),
|
||||
&allocator,
|
||||
/*resizable=*/true,
|
||||
},
|
||||
DispatchKeySet{DispatchKey::CPUTensorId}));
|
||||
|
||||
return tensor.resize_(size, memory_format);
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
} // namespace xnnpack
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
||||
#endif /* USE_XNNPACK */
|
25
aten/src/ATen/native/xnnpack/Factory.h
Normal file
25
aten/src/ATen/native/xnnpack/Factory.h
Normal file
@ -0,0 +1,25 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/native/xnnpack/Common.h>
|
||||
|
||||
#ifdef USE_XNNPACK
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace xnnpack {
|
||||
namespace internal {
|
||||
|
||||
// TODO: Remove this function when at::native::empty() is modified to accept a
|
||||
// custom memory allocator.
|
||||
|
||||
at::Tensor empty_with_tail_padding(
|
||||
IntArrayRef size,
|
||||
const caffe2::TypeMeta dtype,
|
||||
c10::MemoryFormat memory_format);
|
||||
|
||||
} // namespace internal
|
||||
} // namespace xnnpack
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
||||
#endif /* USE_XNNPACK */
|
63
aten/src/ATen/native/xnnpack/Init.cpp
Normal file
63
aten/src/ATen/native/xnnpack/Init.cpp
Normal file
@ -0,0 +1,63 @@
|
||||
#ifdef USE_XNNPACK
|
||||
|
||||
#include <ATen/native/xnnpack/Common.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace xnnpack {
|
||||
namespace internal {
|
||||
namespace {
|
||||
|
||||
bool is_initialized_ = false;
|
||||
|
||||
bool initialize() {
|
||||
using namespace internal;
|
||||
|
||||
// This implementation allows for retries.
|
||||
if (!is_initialized_) {
|
||||
const xnn_status status = xnn_initialize(nullptr);
|
||||
is_initialized_ = (xnn_status_success == status);
|
||||
|
||||
if (!is_initialized_) {
|
||||
if (xnn_status_out_of_memory == status) {
|
||||
TORCH_WARN_ONCE("Failed to initialize XNNPACK! Reason: Out of memory.");
|
||||
} else if (xnn_status_unsupported_hardware == status) {
|
||||
TORCH_WARN_ONCE("Failed to initialize XNNPACK! Reason: Unsupported hardware.");
|
||||
} else {
|
||||
TORCH_WARN_ONCE("Failed to initialize XNNPACK! Reason: Unknown error!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return is_initialized_;
|
||||
}
|
||||
|
||||
bool deinitialize() {
|
||||
using namespace internal;
|
||||
|
||||
// This implementation allows for retries.
|
||||
if (is_initialized_) {
|
||||
const xnn_status status = xnn_deinitialize();
|
||||
is_initialized_ = !(xnn_status_success == status);
|
||||
|
||||
if (is_initialized_) {
|
||||
TORCH_WARN_ONCE("Failed to uninitialize XNNPACK! Reason: Unknown error!");
|
||||
}
|
||||
}
|
||||
|
||||
return !is_initialized_;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool available() {
|
||||
// Add extra conditions here that should disable mobile CPU impl at runtime in its totality.
|
||||
return internal::initialize();
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
} // namespace xnnpack
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
||||
#endif /* USE_XNNPACK */
|
223
aten/src/ATen/native/xnnpack/Linear.cpp
Normal file
223
aten/src/ATen/native/xnnpack/Linear.cpp
Normal file
@ -0,0 +1,223 @@
|
||||
#ifdef USE_XNNPACK
|
||||
|
||||
#include <ATen/cpp_custom_type_hack.h>
|
||||
#include <ATen/native/xnnpack/Common.h>
|
||||
#include <ATen/native/xnnpack/Factory.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace xnnpack {
|
||||
namespace internal {
|
||||
namespace linear {
|
||||
|
||||
struct Context final {
|
||||
Operator linear_op;
|
||||
|
||||
struct Output final {
|
||||
int64_t channels;
|
||||
} output;
|
||||
|
||||
static constexpr float kMin = -std::numeric_limits<float>::infinity();
|
||||
static constexpr float kMax = std::numeric_limits<float>::infinity();
|
||||
};
|
||||
|
||||
namespace {
|
||||
|
||||
// Supports NHWC and NCHW FP32 linear operators.
|
||||
|
||||
// TODO: Decouple and improve error handling and messages.
|
||||
bool available(
|
||||
const Tensor& weight,
|
||||
const c10::optional<Tensor>& bias,
|
||||
const float output_min,
|
||||
const float output_max) {
|
||||
// XNNPACK
|
||||
return xnnpack::internal::available() &&
|
||||
// Weight
|
||||
(2 == weight.ndimension()) &&
|
||||
(c10::DeviceType::CPU == weight.device().type()) &&
|
||||
(kFloat == weight.scalar_type()) &&
|
||||
// Bias
|
||||
((bias && bias->defined()) ? ((1 == bias->ndimension()) &&
|
||||
(c10::DeviceType::CPU == bias->device().type()) &&
|
||||
(kFloat == bias->scalar_type()) &&
|
||||
(weight.size(Layout::Filter::output)) == bias->size(0))
|
||||
: true) &&
|
||||
// Output Min / Max
|
||||
(output_max > output_min) &&
|
||||
true;
|
||||
}
|
||||
|
||||
Context create(
|
||||
const Tensor& weight,
|
||||
const c10::optional<Tensor>& bias,
|
||||
const float output_min,
|
||||
const float output_max) {
|
||||
const Tensor weight_contig = weight.contiguous();
|
||||
|
||||
TORCH_CHECK(
|
||||
available(
|
||||
weight_contig,
|
||||
bias,
|
||||
output_min,
|
||||
output_max),
|
||||
"XNNPACK Linear not available! "
|
||||
"Reason: The provided (weight, bias, output_min, output_max) parameters are "
|
||||
"either invalid individually or their combination is not supported by XNNPACK.");
|
||||
|
||||
xnn_operator_t linear_op{};
|
||||
|
||||
const xnn_status create_status = xnn_create_fully_connected_nc_f32(
|
||||
weight_contig.size(Layout::Filter::input), // input_channels
|
||||
weight_contig.size(Layout::Filter::output), // output_channels
|
||||
weight_contig.size(Layout::Filter::input), // input_pixel_stride
|
||||
weight_contig.size(Layout::Filter::output), // output_pixel_stride
|
||||
weight_contig.data_ptr<float>(), // kernel
|
||||
(bias && bias->defined()) ? bias->data_ptr<float>() : nullptr, // bias
|
||||
output_min, // output_min
|
||||
output_max, // output_max
|
||||
0u, // flags
|
||||
&linear_op); // operator
|
||||
|
||||
TORCH_CHECK(
|
||||
xnn_status_success == create_status,
|
||||
"xnn_create_fully_connected_nc_f32 failed!");
|
||||
|
||||
return Context{
|
||||
Operator(linear_op),
|
||||
{
|
||||
weight_contig.size(Layout::Filter::output),
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// TODO: Decouple and improve error handling and messages.
|
||||
bool usable(const Tensor& input) {
|
||||
// Input
|
||||
return (2 <= input.ndimension()) &&
|
||||
(c10::DeviceType::CPU == input.device().type()) &&
|
||||
(kFloat == input.scalar_type()) &&
|
||||
true;
|
||||
}
|
||||
|
||||
Tensor run(
|
||||
const Context& context,
|
||||
const Tensor& input) {
|
||||
using namespace internal;
|
||||
|
||||
const Tensor& input_contig = input.contiguous();
|
||||
|
||||
TORCH_CHECK(
|
||||
usable(input_contig),
|
||||
"XNNPACK Linear not usable! "
|
||||
"Reason: The provided input tensor is either invalid or unsupported by XNNPACK.");
|
||||
|
||||
const IntArrayRef input_size = input_contig.sizes();
|
||||
std::vector<int64_t> output_size(input_size.cbegin(), input_size.cend());
|
||||
output_size.back() = context.output.channels;
|
||||
|
||||
Tensor output = empty_with_tail_padding(
|
||||
output_size,
|
||||
input_contig.options().dtype(),
|
||||
input_contig.suggest_memory_format());
|
||||
|
||||
const xnn_status setup_status = xnn_setup_fully_connected_nc_f32(
|
||||
context.linear_op.get(), // operator
|
||||
Layout::ActivationND::batch(input_contig.sizes()), // Batch,
|
||||
input_contig.data_ptr<float>(), // input
|
||||
output.data_ptr<float>(), // output
|
||||
nullptr); // threadpool
|
||||
|
||||
TORCH_CHECK(
|
||||
xnn_status_success == setup_status,
|
||||
"xnn_setup_fully_connected_nc_f32 failed!");
|
||||
|
||||
const xnn_status run_status = xnn_run_operator(
|
||||
context.linear_op.get(), // operator
|
||||
nullptr); // threadpool
|
||||
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
xnn_status_success == run_status,
|
||||
"xnn_run_operator failed!");
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor create_and_run(
|
||||
const Tensor& input,
|
||||
const Tensor& weight,
|
||||
const Tensor& bias,
|
||||
const float output_min,
|
||||
const float output_max) {
|
||||
return run(
|
||||
create(
|
||||
weight,
|
||||
bias,
|
||||
output_min,
|
||||
output_max),
|
||||
input);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace linear
|
||||
} // namespace internal
|
||||
|
||||
bool use_linear(
|
||||
const Tensor& input,
|
||||
const Tensor& weight,
|
||||
const Tensor& bias) {
|
||||
return internal::linear::available(
|
||||
weight,
|
||||
bias,
|
||||
internal::linear::Context::kMin,
|
||||
internal::linear::Context::kMax) &&
|
||||
internal::linear::usable(input);
|
||||
}
|
||||
|
||||
Tensor linear(
|
||||
const Tensor& input,
|
||||
const Tensor& weight,
|
||||
const Tensor& bias) {
|
||||
return internal::linear::create_and_run(
|
||||
input,
|
||||
weight,
|
||||
bias,
|
||||
internal::linear::Context::kMin,
|
||||
internal::linear::Context::kMax);
|
||||
}
|
||||
|
||||
} // namespace xnnpack
|
||||
|
||||
Tensor _linear_prepack(
|
||||
const Tensor& weight,
|
||||
const Tensor& bias,
|
||||
const c10::optional<double> output_min,
|
||||
const c10::optional<double> output_max) {
|
||||
return cpp_custom_type_hack::create(
|
||||
std::make_unique<xnnpack::internal::linear::Context>(
|
||||
xnnpack::internal::linear::create(
|
||||
weight,
|
||||
bias,
|
||||
output_min ? *output_min : xnnpack::internal::linear::Context::kMin,
|
||||
output_max ? *output_max : xnnpack::internal::linear::Context::kMax)),
|
||||
weight.options());
|
||||
}
|
||||
|
||||
Tensor _linear_packed(
|
||||
const Tensor& packed_weight,
|
||||
const Tensor& input) {
|
||||
return xnnpack::internal::linear::run(
|
||||
cpp_custom_type_hack::cast<xnnpack::internal::linear::Context>(packed_weight),
|
||||
input);
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
CAFFE_KNOWN_TYPE(at::native::xnnpack::internal::linear::Context);
|
||||
|
||||
} // namespace caffe2
|
||||
|
||||
#endif /* USE_XNNPACK */
|
96
aten/src/ATen/native/xnnpack/Shim.cpp
Normal file
96
aten/src/ATen/native/xnnpack/Shim.cpp
Normal file
@ -0,0 +1,96 @@
|
||||
#ifndef USE_XNNPACK
|
||||
|
||||
#include <ATen/native/xnnpack/Common.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace xnnpack {
|
||||
namespace internal {
|
||||
namespace {
|
||||
|
||||
constexpr const char * const kError =
|
||||
"Not Implemented! Reason: PyTorch not built with XNNPACK support.";
|
||||
|
||||
} // namespace
|
||||
} // namespace internal
|
||||
|
||||
bool available() {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool use_convolution2d(
|
||||
const Tensor&,
|
||||
const Tensor&,
|
||||
const Tensor&,
|
||||
const IntArrayRef,
|
||||
const IntArrayRef,
|
||||
const IntArrayRef,
|
||||
const int64_t,
|
||||
const bool) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Tensor convolution2d(
|
||||
const Tensor&,
|
||||
const Tensor&,
|
||||
const Tensor&,
|
||||
const IntArrayRef,
|
||||
const IntArrayRef,
|
||||
const IntArrayRef,
|
||||
const int64_t,
|
||||
const bool) {
|
||||
TORCH_CHECK(false, internal::kError);
|
||||
}
|
||||
|
||||
bool use_linear(
|
||||
const Tensor&,
|
||||
const Tensor&,
|
||||
const Tensor&) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Tensor linear(
|
||||
const Tensor&,
|
||||
const Tensor&,
|
||||
const Tensor&) {
|
||||
TORCH_CHECK(false, internal::kError);
|
||||
}
|
||||
|
||||
} // namespace xnnpack
|
||||
|
||||
at::Tensor _conv2d_prepack(
|
||||
const Tensor&,
|
||||
const Tensor&,
|
||||
const IntArrayRef,
|
||||
const IntArrayRef,
|
||||
const IntArrayRef,
|
||||
const int64_t,
|
||||
const c10::optional<double>,
|
||||
const c10::optional<double>) {
|
||||
TORCH_CHECK(false, xnnpack::internal::kError);
|
||||
}
|
||||
|
||||
at::Tensor _conv2d_packed(
|
||||
const Tensor&,
|
||||
const Tensor&) {
|
||||
TORCH_CHECK(false, xnnpack::internal::kError);
|
||||
}
|
||||
|
||||
Tensor _linear_prepack(
|
||||
const Tensor&,
|
||||
const Tensor&,
|
||||
const c10::optional<double>,
|
||||
const c10::optional<double>) {
|
||||
TORCH_CHECK(false, xnnpack::internal::kError);
|
||||
}
|
||||
|
||||
Tensor _linear_packed(
|
||||
const Tensor&,
|
||||
const Tensor&) {
|
||||
TORCH_CHECK(false, xnnpack::internal::kError);
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
||||
#endif /* USE_XNNPACK */
|
@ -5,6 +5,7 @@
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/native/TensorFactories.h>
|
||||
#include <ATen/native/utils/Allocator.h>
|
||||
#include <ATen/quantized/QTensorImpl.h>
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <typeinfo>
|
||||
@ -478,37 +479,8 @@ QTensorImpl* get_qtensorimpl(const Tensor& self) {
|
||||
// on a different page out of the process's address space.
|
||||
// Here we define a custom allocator that allocates the extra storage required to keep
|
||||
// this behavior safe. This same allocator can be used for FBGEMM as well.
|
||||
struct QAllocator final : at::Allocator {
|
||||
public:
|
||||
virtual ~QAllocator() override = default;
|
||||
|
||||
virtual at::DataPtr allocate(size_t nbytes) const override {
|
||||
Cast memory{c10::alloc_cpu(kGuard + nbytes)};
|
||||
memory.as_byte_ptr += kGuard;
|
||||
return {
|
||||
memory.as_void_ptr,
|
||||
memory.as_void_ptr,
|
||||
&deleter,
|
||||
at::Device(at::DeviceType::CPU)};
|
||||
}
|
||||
|
||||
virtual at::DeleterFnPtr raw_deleter() const override {
|
||||
return deleter;
|
||||
}
|
||||
|
||||
static void deleter(void * const pointer) {
|
||||
const Cast memory{pointer};
|
||||
c10::free_cpu(memory.as_byte_ptr - kGuard);
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr uint32_t kGuard = 8u;
|
||||
|
||||
union Cast final {
|
||||
void * const as_void_ptr;
|
||||
uint8_t * as_byte_ptr;
|
||||
};
|
||||
};
|
||||
using QAllocator = native::GuardingAllocator<8u, 0u>;
|
||||
|
||||
#endif
|
||||
|
||||
|
@ -151,7 +151,6 @@ else()
|
||||
message(FATAL_ERROR "Unrecognized BLAS option: " ${BLAS})
|
||||
endif()
|
||||
|
||||
|
||||
if (NOT INTERN_BUILD_MOBILE)
|
||||
set(AT_MKL_ENABLED 0)
|
||||
set(AT_MKL_MT 0)
|
||||
@ -180,7 +179,82 @@ if (NOT INTERN_BUILD_MOBILE)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Directory where NNPACK and cpuinfo will download and build all dependencies
|
||||
# ---[ Dependencies
|
||||
# NNPACK and family (QNNPACK, PYTORCH_QNNPACK, and XNNPACK) can download and
|
||||
# compile their dependencies in isolation as part of their build. These dependencies
|
||||
# are then linked statically with PyTorch. To avoid the possibility of a version
|
||||
# mismatch between these shared dependencies, explicitly declare our intent to these
|
||||
# libraries that we are interested in using the exact same source dependencies for all.
|
||||
|
||||
if (USE_NNPACK OR USE_QNNPACK OR USE_PYTORCH_QNNPACK OR USE_XNNPACK)
|
||||
set(DISABLE_NNPACK_AND_FAMILY OFF)
|
||||
|
||||
# Sanity checks - Can we actually build NNPACK and family given the configuration provided?
|
||||
# Disable them and warn the user if not.
|
||||
|
||||
if (IOS)
|
||||
list(LENGTH IOS_ARCH IOS_ARCH_COUNT)
|
||||
if (IOS_ARCH_COUNT GREATER 1)
|
||||
message(WARNING
|
||||
"Multi-architecture (${IOS_ARCH}) builds are not supported in {Q/X}NNPACK. "
|
||||
"Specify a single architecture in IOS_ARCH and re-configure, or "
|
||||
"turn this warning off by USE_{Q/X}NNPACK=OFF.")
|
||||
set(DISABLE_NNPACK_AND_FAMILY ON)
|
||||
endif()
|
||||
if (NOT IOS_ARCH MATCHES "^(i386|x86_64|armv7.*|arm64.*)$")
|
||||
message(WARNING
|
||||
"Target architecture \"${IOS_ARCH}\" is not supported in {Q/X}NNPACK. "
|
||||
"Supported architectures are x86, x86-64, ARM, and ARM64. "
|
||||
"Turn this warning off by USE_{Q/X}NNPACK=OFF.")
|
||||
set(DISABLE_NNPACK_AND_FAMILY ON)
|
||||
endif()
|
||||
else()
|
||||
if (NOT IOS AND NOT (CMAKE_SYSTEM_NAME MATCHES "^(Android|Linux|Darwin)$"))
|
||||
message(WARNING
|
||||
"Target platform \"${CMAKE_SYSTEM_NAME}\" is not supported in {Q/X}NNPACK. "
|
||||
"Supported platforms are Android, iOS, Linux, and macOS. "
|
||||
"Turn this warning off by USE_{Q/X}NNPACK=OFF.")
|
||||
set(DISABLE_NNPACK_AND_FAMILY ON)
|
||||
endif()
|
||||
if (NOT IOS AND NOT (CMAKE_SYSTEM_PROCESSOR MATCHES "^(i686|AMD64|x86_64|armv[0-9].*|arm64|aarch64)$"))
|
||||
message(WARNING
|
||||
"Target architecture \"${CMAKE_SYSTEM_PROCESSOR}\" is not supported in {Q/X}NNPACK. "
|
||||
"Supported architectures are x86, x86-64, ARM, and ARM64. "
|
||||
"Turn this warning off by USE_{Q/X}NNPACK=OFF.")
|
||||
set(DISABLE_NNPACK_AND_FAMILY ON)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (DISABLE_NNPACK_AND_FAMILY)
|
||||
set(USE_NNPACK OFF)
|
||||
set(USE_QNNPACK OFF)
|
||||
set(USE_PYTORCH_QNNPACK OFF)
|
||||
set(USE_XNNPACK OFF)
|
||||
else()
|
||||
set(CAFFE2_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party")
|
||||
|
||||
if (NOT DEFINED CPUINFO_SOURCE_DIR)
|
||||
set(CPUINFO_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/cpuinfo" CACHE STRING "cpuinfo source directory")
|
||||
endif()
|
||||
if (NOT DEFINED FP16_SOURCE_DIR)
|
||||
set(FP16_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/FP16" CACHE STRING "FP16 source directory")
|
||||
endif()
|
||||
if (NOT DEFINED FXDIV_SOURCE_DIR)
|
||||
set(FXDIV_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/FXdiv" CACHE STRING "FXdiv source directory")
|
||||
endif()
|
||||
if (NOT DEFINED PSIMD_SOURCE_DIR)
|
||||
set(PSIMD_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/psimd" CACHE STRING "PSimd source directory")
|
||||
endif()
|
||||
if (NOT DEFINED PTHREADPOOL_SOURCE_DIR)
|
||||
set(PTHREADPOOL_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/pthreadpool" CACHE STRING "pthreadpool source directory")
|
||||
endif()
|
||||
|
||||
set(CPUINFO_LIBRARY_TYPE "static" CACHE STRING "")
|
||||
set(CPUINFO_LOG_LEVEL "error" CACHE STRING "")
|
||||
set(PTHREADPOOL_LIBRARY_TYPE "static" CACHE STRING "")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set(CONFU_DEPENDENCIES_SOURCE_DIR ${PROJECT_BINARY_DIR}/confu-srcs
|
||||
CACHE PATH "Confu-style dependencies source directory")
|
||||
set(CONFU_DEPENDENCIES_BINARY_DIR ${PROJECT_BINARY_DIR}/confu-deps
|
||||
@ -199,7 +273,7 @@ endif()
|
||||
# instead of the default implementation. To avoid confusion, add pthreadpool
|
||||
# subdirectory explicitly with EXCLUDE_FROM_ALL property prior to QNNPACK/NNPACK
|
||||
# does so, which will prevent it from installing the default pthreadpool library.
|
||||
if(INTERN_BUILD_MOBILE AND NOT BUILD_CAFFE2_MOBILE AND (USE_QNNPACK OR USE_NNPACK))
|
||||
if(INTERN_BUILD_MOBILE AND NOT BUILD_CAFFE2_MOBILE AND (USE_QNNPACK OR USE_NNPACK OR USE_XNNPACK))
|
||||
if(NOT DEFINED PTHREADPOOL_SOURCE_DIR)
|
||||
set(CAFFE2_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party")
|
||||
set(PTHREADPOOL_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/pthreadpool" CACHE STRING "pthreadpool source directory")
|
||||
@ -217,81 +291,28 @@ endif()
|
||||
|
||||
# ---[ QNNPACK
|
||||
if(USE_QNNPACK)
|
||||
if (IOS)
|
||||
list(LENGTH IOS_ARCH IOS_ARCH_COUNT)
|
||||
if (IOS_ARCH_COUNT GREATER 1)
|
||||
message(WARNING
|
||||
"Multi-architecture (${IOS_ARCH}) builds are not supported in QNNPACK. "
|
||||
"Specify a single architecture in IOS_ARCH and re-configure, or "
|
||||
"turn this warning off by USE_QNNPACK=OFF.")
|
||||
set(USE_QNNPACK OFF)
|
||||
endif()
|
||||
if (NOT IOS_ARCH MATCHES "^(i386|x86_64|armv7.*|arm64.*)$")
|
||||
message(WARNING
|
||||
"Target architecture \"${IOS_ARCH}\" is not supported in QNNPACK. "
|
||||
"Supported architectures are x86, x86-64, ARM, and ARM64. "
|
||||
"Turn this warning off by USE_QNNPACK=OFF.")
|
||||
set(USE_QNNPACK OFF)
|
||||
endif()
|
||||
else()
|
||||
if (NOT IOS AND NOT (CMAKE_SYSTEM_NAME MATCHES "^(Android|Linux|Darwin)$"))
|
||||
message(WARNING
|
||||
"Target platform \"${CMAKE_SYSTEM_NAME}\" is not supported in QNNPACK. "
|
||||
"Supported platforms are Android, iOS, Linux, and macOS. "
|
||||
"Turn this warning off by USE_QNNPACK=OFF.")
|
||||
set(USE_QNNPACK OFF)
|
||||
endif()
|
||||
if (NOT IOS AND NOT (CMAKE_SYSTEM_PROCESSOR MATCHES "^(i686|AMD64|x86_64|armv[0-9].*|arm64|aarch64)$"))
|
||||
message(WARNING
|
||||
"Target architecture \"${CMAKE_SYSTEM_PROCESSOR}\" is not supported in QNNPACK. "
|
||||
"Supported architectures are x86, x86-64, ARM, and ARM64. "
|
||||
"Turn this warning off by USE_QNNPACK=OFF.")
|
||||
set(USE_QNNPACK OFF)
|
||||
endif()
|
||||
set(CAFFE2_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party")
|
||||
|
||||
if (NOT DEFINED QNNPACK_SOURCE_DIR)
|
||||
set(QNNPACK_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/QNNPACK" CACHE STRING "QNNPACK source directory")
|
||||
endif()
|
||||
if (USE_QNNPACK)
|
||||
set(CAFFE2_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party")
|
||||
|
||||
# Directories for QNNPACK dependencies submoduled in Caffe2
|
||||
if (NOT DEFINED CPUINFO_SOURCE_DIR)
|
||||
set(CPUINFO_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/cpuinfo" CACHE STRING "cpuinfo source directory")
|
||||
endif()
|
||||
if (NOT DEFINED QNNPACK_SOURCE_DIR)
|
||||
set(QNNPACK_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/QNNPACK" CACHE STRING "QNNPACK source directory")
|
||||
endif()
|
||||
if (NOT DEFINED FP16_SOURCE_DIR)
|
||||
set(FP16_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/FP16" CACHE STRING "FP16 source directory")
|
||||
endif()
|
||||
if (NOT DEFINED FXDIV_SOURCE_DIR)
|
||||
set(FXDIV_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/FXdiv" CACHE STRING "FXdiv source directory")
|
||||
endif()
|
||||
if (NOT DEFINED PSIMD_SOURCE_DIR)
|
||||
set(PSIMD_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/psimd" CACHE STRING "PSimd source directory")
|
||||
endif()
|
||||
if (NOT DEFINED PTHREADPOOL_SOURCE_DIR)
|
||||
set(PTHREADPOOL_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/pthreadpool" CACHE STRING "pthreadpool source directory")
|
||||
endif()
|
||||
|
||||
if(NOT TARGET qnnpack)
|
||||
set(QNNPACK_BUILD_TESTS OFF CACHE BOOL "")
|
||||
set(QNNPACK_BUILD_BENCHMARKS OFF CACHE BOOL "")
|
||||
set(QNNPACK_CUSTOM_THREADPOOL ON CACHE BOOL "")
|
||||
set(QNNPACK_LIBRARY_TYPE "static" CACHE STRING "")
|
||||
set(PTHREADPOOL_LIBRARY_TYPE "static" CACHE STRING "")
|
||||
set(CPUINFO_LIBRARY_TYPE "static" CACHE STRING "")
|
||||
set(CPUINFO_LOG_LEVEL "error" CACHE STRING "")
|
||||
add_subdirectory(
|
||||
"${QNNPACK_SOURCE_DIR}"
|
||||
"${CONFU_DEPENDENCIES_BINARY_DIR}/QNNPACK")
|
||||
# We build static versions of QNNPACK and pthreadpool but link
|
||||
# them into a shared library for Caffe2, so they need PIC.
|
||||
set_property(TARGET qnnpack PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
set_property(TARGET pthreadpool PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
set_property(TARGET cpuinfo PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
endif()
|
||||
|
||||
list(APPEND Caffe2_DEPENDENCY_LIBS qnnpack)
|
||||
if(NOT TARGET qnnpack)
|
||||
set(QNNPACK_BUILD_TESTS OFF CACHE BOOL "")
|
||||
set(QNNPACK_BUILD_BENCHMARKS OFF CACHE BOOL "")
|
||||
set(QNNPACK_CUSTOM_THREADPOOL ON CACHE BOOL "")
|
||||
set(QNNPACK_LIBRARY_TYPE "static" CACHE STRING "")
|
||||
add_subdirectory(
|
||||
"${QNNPACK_SOURCE_DIR}"
|
||||
"${CONFU_DEPENDENCIES_BINARY_DIR}/QNNPACK")
|
||||
# We build static versions of QNNPACK and pthreadpool but link
|
||||
# them into a shared library for Caffe2, so they need PIC.
|
||||
set_property(TARGET qnnpack PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
set_property(TARGET pthreadpool PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
set_property(TARGET cpuinfo PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
endif()
|
||||
|
||||
list(APPEND Caffe2_DEPENDENCY_LIBS qnnpack)
|
||||
endif()
|
||||
|
||||
# ---[ Caffe2 Int8 operators (enabled by USE_QNNPACK) depend on gemmlowp and neon2sse headers
|
||||
@ -303,63 +324,26 @@ endif()
|
||||
|
||||
# ---[ PYTORCH_QNNPACK
|
||||
if(USE_PYTORCH_QNNPACK)
|
||||
if (IOS)
|
||||
list(LENGTH IOS_ARCH IOS_ARCH_COUNT)
|
||||
if (IOS_ARCH_COUNT GREATER 1)
|
||||
message(WARNING
|
||||
"Multi-architecture (${IOS_ARCH}) builds are not supported in QNNPACK. "
|
||||
"Specify a single architecture in IOS_ARCH and re-configure, or "
|
||||
"turn this warning off by USE_PYTORCH_QNNPACK=OFF.")
|
||||
set(USE_PYTORCH_QNNPACK OFF)
|
||||
if (NOT DEFINED PYTORCH_QNNPACK_SOURCE_DIR)
|
||||
set(PYTORCH_QNNPACK_SOURCE_DIR "${PROJECT_SOURCE_DIR}/aten/src/ATen/native/quantized/cpu/qnnpack" CACHE STRING "QNNPACK source directory")
|
||||
endif()
|
||||
if (NOT IOS_ARCH MATCHES "^(i386|x86_64|armv7.*|arm64.*)$")
|
||||
message(WARNING
|
||||
"Target architecture \"${IOS_ARCH}\" is not supported in QNNPACK. "
|
||||
"Supported architectures are x86, x86-64, ARM, and ARM64. "
|
||||
"Turn this warning off by USE_PYTORCH_QNNPACK=OFF.")
|
||||
set(USE_PYTORCH_QNNPACK OFF)
|
||||
endif()
|
||||
else()
|
||||
if (NOT IOS AND NOT (CMAKE_SYSTEM_NAME MATCHES "^(Android|Linux|Darwin)$"))
|
||||
message(WARNING
|
||||
"Target platform \"${CMAKE_SYSTEM_NAME}\" is not supported in QNNPACK. "
|
||||
"Supported platforms are Android, iOS, Linux, and macOS. "
|
||||
"Turn this warning off by USE_PYTORCH_QNNPACK=OFF.")
|
||||
set(USE_PYTORCH_QNNPACK OFF)
|
||||
endif()
|
||||
if (NOT IOS AND NOT (CMAKE_SYSTEM_PROCESSOR MATCHES "^(i686|AMD64|x86_64|armv[0-9].*|arm64|aarch64)$"))
|
||||
message(WARNING
|
||||
"Target architecture \"${CMAKE_SYSTEM_PROCESSOR}\" is not supported in QNNPACK. "
|
||||
"Supported architectures are x86, x86-64, ARM, and ARM64. "
|
||||
"Turn this warning off by USE_PYTORCH_QNNPACK=OFF.")
|
||||
set(USE_PYTORCH_QNNPACK OFF)
|
||||
endif()
|
||||
endif()
|
||||
if (USE_PYTORCH_QNNPACK)
|
||||
if (NOT DEFINED PYTORCH_QNNPACK_SOURCE_DIR)
|
||||
set(PYTORCH_QNNPACK_SOURCE_DIR "${PROJECT_SOURCE_DIR}/aten/src/ATen/native/quantized/cpu/qnnpack" CACHE STRING "QNNPACK source directory")
|
||||
endif()
|
||||
|
||||
if(NOT TARGET pytorch_qnnpack)
|
||||
set(PYTORCH_QNNPACK_BUILD_TESTS OFF CACHE BOOL "")
|
||||
set(PYTORCH_QNNPACK_BUILD_BENCHMARKS OFF CACHE BOOL "")
|
||||
set(PYTORCH_QNNPACK_CUSTOM_THREADPOOL ON CACHE BOOL "")
|
||||
set(PYTORCH_QNNPACK_LIBRARY_TYPE "static" CACHE STRING "")
|
||||
set(PTHREADPOOL_LIBRARY_TYPE "static" CACHE STRING "")
|
||||
set(CPUINFO_LIBRARY_TYPE "static" CACHE STRING "")
|
||||
set(CPUINFO_LOG_LEVEL "error" CACHE STRING "")
|
||||
add_subdirectory(
|
||||
"${PYTORCH_QNNPACK_SOURCE_DIR}"
|
||||
"${CONFU_DEPENDENCIES_BINARY_DIR}/pytorch_qnnpack")
|
||||
# We build static versions of QNNPACK and pthreadpool but link
|
||||
# them into a shared library for Caffe2, so they need PIC.
|
||||
set_property(TARGET pytorch_qnnpack PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
set_property(TARGET pthreadpool PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
set_property(TARGET cpuinfo PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
endif()
|
||||
if(NOT TARGET pytorch_qnnpack)
|
||||
set(PYTORCH_QNNPACK_BUILD_TESTS OFF CACHE BOOL "")
|
||||
set(PYTORCH_QNNPACK_BUILD_BENCHMARKS OFF CACHE BOOL "")
|
||||
set(PYTORCH_QNNPACK_CUSTOM_THREADPOOL ON CACHE BOOL "")
|
||||
set(PYTORCH_QNNPACK_LIBRARY_TYPE "static" CACHE STRING "")
|
||||
add_subdirectory(
|
||||
"${PYTORCH_QNNPACK_SOURCE_DIR}"
|
||||
"${CONFU_DEPENDENCIES_BINARY_DIR}/pytorch_qnnpack")
|
||||
# We build static versions of QNNPACK and pthreadpool but link
|
||||
# them into a shared library for Caffe2, so they need PIC.
|
||||
set_property(TARGET pytorch_qnnpack PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
set_property(TARGET pthreadpool PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
set_property(TARGET cpuinfo PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
endif()
|
||||
|
||||
list(APPEND Caffe2_DEPENDENCY_LIBS pytorch_qnnpack)
|
||||
endif()
|
||||
list(APPEND Caffe2_DEPENDENCY_LIBS pytorch_qnnpack)
|
||||
endif()
|
||||
|
||||
# ---[ NNPACK
|
||||
@ -379,6 +363,33 @@ if(USE_NNPACK)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# ---[ XNNPACK
|
||||
if(USE_XNNPACK)
|
||||
if (NOT DEFINED XNNPACK_SOURCE_DIR)
|
||||
set(XNNPACK_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/XNNPACK" CACHE STRING "XNNPACK source directory")
|
||||
endif()
|
||||
|
||||
if (NOT DEFINED XNNPACK_INCLUDE_DIR)
|
||||
set(XNNPACK_INCLUDE_DIR "${XNNPACK_SOURCE_DIR}/include" CACHE STRING "XNNPACK include directory")
|
||||
endif()
|
||||
|
||||
if(NOT TARGET XNNPACK)
|
||||
set(XNNPACK_CUSTOM_THREADPOOL ON CACHE BOOL "")
|
||||
set(XNNPACK_LIBRARY_TYPE "static" CACHE STRING "")
|
||||
set(XNNPACK_BUILD_BENCHMARKS OFF CACHE BOOL "")
|
||||
set(XNNPACK_BUILD_TESTS OFF CACHE BOOL "")
|
||||
|
||||
add_subdirectory(
|
||||
"${XNNPACK_SOURCE_DIR}"
|
||||
"${CONFU_DEPENDENCIES_BINARY_DIR}/XNNPACK")
|
||||
|
||||
set_property(TARGET XNNPACK PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
endif()
|
||||
|
||||
include_directories(SYSTEM ${XNNPACK_INCLUDE_DIR})
|
||||
list(APPEND Caffe2_DEPENDENCY_LIBS XNNPACK)
|
||||
endif()
|
||||
|
||||
# ---[ Caffe2 uses cpuinfo library in the thread pool
|
||||
if (NOT TARGET cpuinfo)
|
||||
if (NOT DEFINED CPUINFO_SOURCE_DIR)
|
||||
|
@ -64,6 +64,11 @@ if (NOT @BUILD_SHARED_LIBS@)
|
||||
list(APPEND TORCH_LIBRARIES ${PYTORCH_QNNPACK_LIBRARY})
|
||||
endif()
|
||||
|
||||
if (@USE_XNNPACK@)
|
||||
find_library(XNNPACK_LIBRARY XNNPACK PATHS "${TORCH_INSTALL_PREFIX}/lib")
|
||||
list(APPEND TORCH_LIBRARIES ${XNNPACK_LIBRARY})
|
||||
endif()
|
||||
|
||||
if (@INTERN_USE_EIGEN_BLAS@)
|
||||
find_library(EIGEN_BLAS_LIBRARY eigen_blas PATHS "${TORCH_INSTALL_PREFIX}/lib")
|
||||
list(APPEND TORCH_LIBRARIES ${EIGEN_BLAS_LIBRARY})
|
||||
|
@ -11,11 +11,11 @@ option_parser = OptionParser.new do |opts|
|
||||
end.parse!
|
||||
puts "Current directory: #{Dir.pwd}"
|
||||
install_path = File.expand_path("../../../build_ios/install")
|
||||
if not Dir.exist? (install_path)
|
||||
if not Dir.exist? (install_path)
|
||||
raise "path doesn't exist:#{install_path}!"
|
||||
end
|
||||
xcodeproj_path = File.expand_path("../TestApp.xcodeproj")
|
||||
if not File.exist? (xcodeproj_path)
|
||||
if not File.exist? (xcodeproj_path)
|
||||
raise "path doesn't exist:#{xcodeproj_path}!"
|
||||
end
|
||||
puts "Setting up TestApp.xcodeproj..."
|
||||
@ -63,7 +63,7 @@ targets.each do |target|
|
||||
target.resources_build_phase.add_file_reference(config_file_ref, true)
|
||||
end
|
||||
puts "Linking static libraries..."
|
||||
libs = ['libc10.a', 'libclog.a', 'libnnpack.a', 'libeigen_blas.a', 'libcpuinfo.a', 'libpytorch_qnnpack.a', 'libtorch_cpu.a', 'libtorch.a']
|
||||
libs = ['libc10.a', 'libclog.a', 'libnnpack.a', 'libXNNPACK.a', 'libeigen_blas.a', 'libcpuinfo.a', 'libpytorch_qnnpack.a', 'libtorch_cpu.a', 'libtorch.a']
|
||||
targets.each do |target|
|
||||
target.frameworks_build_phases.clear
|
||||
for lib in libs do
|
||||
|
@ -2,7 +2,7 @@ require 'optparse'
|
||||
require 'xcodeproj'
|
||||
|
||||
options = {}
|
||||
option_parser = OptionParser.new do |opts|
|
||||
option_parser = OptionParser.new do |opts|
|
||||
opts.banner = 'Tools for building PyTorch iOS framework on MacOS'
|
||||
opts.on('-i', '--install_path ', 'path to the cmake install folder') { |value|
|
||||
options[:install] = value
|
||||
@ -23,11 +23,11 @@ end.parse!
|
||||
puts options.inspect
|
||||
|
||||
install_path = File.expand_path(options[:install])
|
||||
if not Dir.exist? (install_path)
|
||||
if not Dir.exist? (install_path)
|
||||
raise "path don't exist:#{install_path}!"
|
||||
end
|
||||
xcodeproj_path = File.expand_path(options[:xcodeproj])
|
||||
if not File.exist? (xcodeproj_path)
|
||||
if not File.exist? (xcodeproj_path)
|
||||
raise "path don't exist:#{xcodeproj_path}!"
|
||||
end
|
||||
|
||||
@ -51,8 +51,8 @@ end
|
||||
|
||||
# link static libraries
|
||||
target.frameworks_build_phases.clear
|
||||
libs = ['libc10.a', 'libclog.a', 'libnnpack.a', 'libeigen_blas.a', 'libcpuinfo.a', 'libpytorch_qnnpack.a', 'libtorch_cpu.a', 'libtorch.a']
|
||||
for lib in libs do
|
||||
libs = ['libc10.a', 'libclog.a', 'libnnpack.a', 'libXNNPACK.a', 'libeigen_blas.a', 'libcpuinfo.a', 'libpytorch_qnnpack.a', 'libtorch_cpu.a', 'libtorch.a']
|
||||
for lib in libs do
|
||||
path = "#{install_path}/lib/#{lib}"
|
||||
if File.exist?(path)
|
||||
libref = project.frameworks_group.new_file(path)
|
||||
@ -68,12 +68,12 @@ elsif options[:platform] == 'OS'
|
||||
sdk = 'iphoneos'
|
||||
else
|
||||
raise "unsupported platform #{options[:platform]}"
|
||||
end
|
||||
end
|
||||
|
||||
profile = options[:profile]
|
||||
if not profile
|
||||
if not profile
|
||||
raise "no provisioning profile found!"
|
||||
end
|
||||
end
|
||||
|
||||
# run xcodebuild
|
||||
exec "xcodebuild clean build -project #{xcodeproj_path} -target #{target.name} -sdk #{sdk} -configuration Release PROVISIONING_PROFILE_SPECIFIER=#{profile}"
|
||||
|
1
third_party/XNNPACK
vendored
Submodule
1
third_party/XNNPACK
vendored
Submodule
Submodule third_party/XNNPACK added at fa611cc5c2
2
third_party/cpuinfo
vendored
2
third_party/cpuinfo
vendored
Submodule third_party/cpuinfo updated: 89fe1695ed...0e6bde92b3
2
third_party/psimd
vendored
2
third_party/psimd
vendored
Submodule third_party/psimd updated: 90a938f30b...10b4ffc6ea
2
third_party/pthreadpool
vendored
2
third_party/pthreadpool
vendored
Submodule third_party/pthreadpool updated: 13da0b4c21...d465747660
Reference in New Issue
Block a user