pytorch/docs/source/notes/cpu_threading_torchscript_inference.rst
pinzhenx bd604cb5b7 Upgrade MKL-DNN to DNNL v1.2 (#32422)
Summary:
## Motivation

This PR upgrades MKL-DNN from v0.20 to DNNL v1.2 and resolves https://github.com/pytorch/pytorch/issues/30300.

DNNL (Deep Neural Network Library) is the new brand of MKL-DNN, which improves performance, quality, and usability over the old version.

This PR focuses on migrating all existing functionality, and includes minor fixes, performance improvements, and code cleanup. It serves as the cornerstone of our future efforts to accommodate new features like OpenCL support, BF16 training, INT8 inference, etc., and to let the PyTorch community derive more benefit from Intel Architecture.

<br>

## What's included?

Even though DNNL has many breaking changes to its API, we managed to absorb most of them in ideep. This PR contains minimal changes to the integration code in PyTorch. Below is a summary of the changes:

<br>

**General:**

1. Replace op-level allocator with global-registered allocator

```
// before
ideep::sum::compute<AllocForMKLDNN>(scales, {x, y}, z);

// after
ideep::sum::compute(scales, {x, y}, z);
```

The allocator is now registered in `aten/src/ATen/native/mkldnn/IDeepRegistration.cpp`. From then on, all tensors derived from the `cpu_engine` (the default) use the c10 allocator.

```
RegisterEngineAllocator cpu_alloc(
  ideep::engine::cpu_engine(),
  [](size_t size) {
    return c10::GetAllocator(c10::DeviceType::CPU)->raw_allocate(size);
  },
  [](void* p) {
    c10::GetAllocator(c10::DeviceType::CPU)->raw_deallocate(p);
  }
);
```
------

2. Simplify group convolution

In convolution, there was a case where the ideep tensor shape did not match the aten tensor shape: when `groups > 1`, DNNL expects weight tensors to be 5-d with an extra group dimension, e.g. `goihw` instead of `oihw` in the 2d conv case.

As shown below, a lot of extra checks came with this difference in shape before. Now we've completely hidden this difference in ideep and all tensors are going to align with pytorch's definition. So we could safely remove these checks from both aten and c2 integration code.

```
// aten/src/ATen/native/mkldnn/Conv.cpp

if (w.ndims() == x.ndims() + 1) {
  AT_ASSERTM(
      groups > 1,
      "Only group _mkldnn_conv2d weights could have been reordered to 5d");
  kernel_size[0] = w.get_dim(0) * w.get_dim(1);
  std::copy_n(
      w.get_dims().cbegin() + 2, x.ndims() - 1, kernel_size.begin() + 1);
} else {
  std::copy_n(w.get_dims().cbegin(), x.ndims(), kernel_size.begin());
}
```

------

3. Enable DNNL built-in cache

Previously, we stored DNNL jitted kernels along with intermediate buffers inside ideep using an LRU cache. Now we are switching to the newly added DNNL built-in cache, and **no longer** caching buffers in order to reduce memory footprint.

This change will mainly show up as lower memory usage in memory profiling results. On the code side, we removed a couple of lines using `op_key_` that previously depended on the ideep cache.

------

4. Use 64-bit integer to denote dimensions

We changed the type of `ideep::dims` from `vector<int32_t>` to `vector<int64_t>`. As a result, ideep dims are no longer compatible with the 32-bit dims used by caffe2, so we use something like `{stride_.begin(), stride_.end()}` to convert the parameter `stride_` into an int64 vector.

<br>

**Misc changes in each commit:**

**Commit:** change build options

Some build options were slightly changed, mainly to avoid name collisions with other projects that include DNNL as a subproject. In addition, DNNL built-in cache is enabled by option `DNNL_ENABLE_PRIMITIVE_CACHE`.

Old | New
-- | --
WITH_EXAMPLE | MKLDNN_BUILD_EXAMPLES
WITH_TEST | MKLDNN_BUILD_TESTS
MKLDNN_THREADING | MKLDNN_CPU_RUNTIME
MKLDNN_USE_MKL | N/A (MKL is no longer used)

------

**Commit:** aten reintegration

- aten/src/ATen/native/mkldnn/BinaryOps.cpp

    Implement binary ops using new operation `binary` provided by DNNL

- aten/src/ATen/native/mkldnn/Conv.cpp

    Clean up group convolution checks
    Simplify conv backward integration

- aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp

    Simplify prepacking convolution weights

- test/test_mkldnn.py

    Fixed an issue in the conv2d unit test: previously it did not compare conv results between the mkldnn and aten implementations. Instead, it compared mkldnn with mkldnn, because the default CPU path also goes into mkldnn. Now we use `torch.backends.mkldnn.flags` to fix this issue

- torch/utils/mkldnn.py

    Prepack the weight tensor in module `__init__` to achieve significantly better performance

------

**Commit:** caffe2 reintegration

- caffe2/ideep/ideep_utils.h

    Clean up unused type definitions

- caffe2/ideep/operators/adam_op.cc & caffe2/ideep/operators/momentum_sgd_op.cc

   Unify tensor initialization with `ideep::tensor::init`. Obsolete `ideep::tensor::reinit`

- caffe2/ideep/operators/conv_op.cc & caffe2/ideep/operators/quantization/int8_conv_op.cc

    Clean up group convolution checks
    Revamp convolution API

- caffe2/ideep/operators/conv_transpose_op.cc

    Clean up group convolution checks
    Clean up deconv workaround code

------

**Commit:** custom allocator

- Register c10 allocator as mentioned above

<br><br>

## Performance

We tested inference on some common models based on user scenarios, and most performance numbers are either better than or on par with MKL-DNN v0.20.

ratio: new / old | Latency (batch=1 4T) | Throughput (batch=64 56T)
-- | -- | --
pytorch resnet18 | 121.4% | 99.7%
pytorch resnet50 | 123.1% | 106.9%
pytorch resnext101_32x8d | 116.3% | 100.1%
pytorch resnext50_32x4d | 141.9% | 104.4%
pytorch mobilenet_v2 | 163.0% | 105.8%
caffe2 alexnet | 303.0% | 99.2%
caffe2 googlenet-v3 | 101.1% | 99.2%
caffe2 inception-v1 | 102.2% | 101.7%
caffe2 mobilenet-v1 | 356.1% | 253.7%
caffe2 resnet101 | 100.4% | 99.8%
caffe2 resnet152 | 99.8% | 99.8%
caffe2 shufflenet | 141.1% | 69.0% †
caffe2 squeezenet | 98.5% | 99.2%
caffe2 vgg16 | 136.8% | 100.6%
caffe2 googlenet-v3 int8 | 100.0% | 100.7%
caffe2 mobilenet-v1 int8 | 779.2% | 943.0%
caffe2 resnet50 int8 | 99.5% | 95.5%

_Configuration:
Platform: Skylake 8180
Latency Test: 4 threads, warmup 30, iteration 500, batch size 1
Throughput Test: 56 threads, warmup 30, iteration 200, batch size 64_

† Shufflenet is one of the few models that require temp buffers during inference. The performance degradation is expected, since we no longer cache any buffers in ideep. As a workaround, we suggest users opt for a caching allocator such as **jemalloc** as a drop-in replacement for the system allocator in such heavy workloads.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32422

Test Plan:
Perf results: https://our.intern.facebook.com/intern/fblearner/details/177790608?tab=Experiment%20Results

10% improvement for ResNext with avx512, neutral on avx2

More results: https://fb.quip.com/ob10AL0bCDXW#NNNACAUoHJP

Reviewed By: yinghai

Differential Revision: D20381325

Pulled By: dzhulgakov

fbshipit-source-id: 803b906fd89ed8b723c5fcab55039efe3e4bcb77
2020-03-26 22:07:59 -07:00


.. _cpu-threading-torchscript-inference:

CPU threading and TorchScript inference
=================================================

PyTorch allows using multiple CPU threads during TorchScript model inference.
The following figure shows different levels of parallelism one would find in a
typical application:

.. image:: cpu_threading_torchscript_inference.svg
   :width: 75%

One or more inference threads execute a model's forward pass on the given inputs.
Each inference thread invokes a JIT interpreter that executes the ops
of a model inline, one by one. A model can utilize a ``fork`` TorchScript
primitive to launch an asynchronous task. Forking several operations at once
results in tasks that are executed in parallel. The ``fork`` operator returns a
``Future`` object which can be used to synchronize later, for example:

.. code-block:: python

    import torch

    # The weights are passed explicitly: a free function has no ``self``.
    @torch.jit.script
    def compute_z(x, w_z):
        return torch.mm(x, w_z)

    @torch.jit.script
    def forward(x, w_y, w_z):
        # launch compute_z asynchronously:
        fut = torch.jit._fork(compute_z, x, w_z)
        # execute the next operation in parallel to compute_z:
        y = torch.mm(x, w_y)
        # wait for the result of compute_z:
        z = torch.jit._wait(fut)
        return y + z

PyTorch uses a single thread pool for inter-op parallelism; this thread pool
is shared by all inference tasks that are forked within the application process.

In addition to inter-op parallelism, PyTorch can also utilize multiple threads
within the ops (`intra-op parallelism`). This can be useful in many cases,
including element-wise ops on large tensors, convolutions, GEMMs, embedding
lookups and others.

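As a rough analogy in plain Python (not TorchScript, and not how PyTorch implements it), the fork/wait pattern with a shared inter-op pool resembles submitting work to one process-wide ``concurrent.futures.ThreadPoolExecutor``. The names ``INTER_OP_POOL``, ``compute_z`` and ``forward`` are illustrative assumptions, and plain lists stand in for tensors:

```python
from concurrent.futures import ThreadPoolExecutor

# One pool shared by every "forward" call in the process,
# mirroring the single inter-op thread pool described above.
INTER_OP_POOL = ThreadPoolExecutor(max_workers=4)


def compute_z(x):
    # Stand-in for an op that can run asynchronously.
    return [2 * v for v in x]


def forward(x):
    # "fork": schedule compute_z on the shared pool and get a Future back.
    fut = INTER_OP_POOL.submit(compute_z, x)
    # Do other work while compute_z may be running.
    y = [v + 1 for v in x]
    # "wait": block until the forked task has finished.
    z = fut.result()
    return [a + b for a, b in zip(y, z)]


print(forward([1, 2, 3]))  # [4, 7, 10]
```

Because the pool is created once at module level, concurrent callers of ``forward`` compete for the same workers rather than each creating their own threads.
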
Build options
-------------
PyTorch uses an internal ATen library to implement ops. In addition to that,
PyTorch can also be built with support for external libraries, such as MKL_ and MKL-DNN_,
to speed up computations on CPU.

ATen, MKL and MKL-DNN support intra-op parallelism and depend on the
following parallelization libraries to implement it:

* OpenMP_ - a standard (and a library, usually shipped with a compiler), widely used in external libraries;
* TBB_ - a newer parallelization library optimized for task-based parallelism and concurrent environments.

OpenMP historically has been used by a large number of libraries. It is known
for its relative ease of use and support for loop-based parallelism and other primitives.

TBB is used to a lesser extent in external libraries, but, at the same time,
is optimized for concurrent environments. PyTorch's TBB backend guarantees that
there's a separate, single, per-process intra-op thread pool used by all of the
ops running in the application.

Depending on the use case, one might find one or the other parallelization
library a better choice for their application.

PyTorch allows selecting the parallelization backend used by ATen and other
libraries at build time with the following build options:


+------------+------------------------+-----------------------------+----------------------------------------+
| Library    | Build Option           | Values                      | Notes                                  |
+============+========================+=============================+========================================+
| ATen       | ``ATEN_THREADING``     | ``OMP`` (default), ``TBB``  |                                        |
+------------+------------------------+-----------------------------+----------------------------------------+
| MKL        | ``MKL_THREADING``      | (same)                      | To enable MKL use ``BLAS=MKL``         |
+------------+------------------------+-----------------------------+----------------------------------------+
| MKL-DNN    | ``MKLDNN_CPU_RUNTIME`` | (same)                      | To enable MKL-DNN use ``USE_MKLDNN=1`` |
+------------+------------------------+-----------------------------+----------------------------------------+

It is recommended not to mix OpenMP and TBB within one build.

Any of the ``TBB`` values above requires the ``USE_TBB=1`` build setting (default: OFF).
A separate setting, ``USE_OPENMP=1`` (default: ON), is required for OpenMP parallelism.

Runtime API
-----------

The following API is used to control thread settings:


+------------------------+-----------------------------------------------------------+---------------------------------------------------------+
| Type of parallelism    | Settings                                                  | Notes                                                   |
+========================+===========================================================+=========================================================+
| Inter-op parallelism   | ``at::set_num_interop_threads``,                          | Default number of threads: number of CPU cores.         |
|                        | ``at::get_num_interop_threads`` (C++)                     |                                                         |
|                        |                                                           |                                                         |
|                        | ``set_num_interop_threads``,                              |                                                         |
|                        | ``get_num_interop_threads`` (Python, :mod:`torch` module) |                                                         |
+------------------------+-----------------------------------------------------------+                                                         |
| Intra-op parallelism   | ``at::set_num_threads``,                                  |                                                         |
|                        | ``at::get_num_threads`` (C++)                             |                                                         |
|                        | ``set_num_threads``,                                      |                                                         |
|                        | ``get_num_threads`` (Python, :mod:`torch` module)         |                                                         |
|                        |                                                           |                                                         |
|                        | Environment variables:                                    |                                                         |
|                        | ``OMP_NUM_THREADS`` and ``MKL_NUM_THREADS``               |                                                         |
+------------------------+-----------------------------------------------------------+---------------------------------------------------------+

For the intra-op parallelism settings, ``at::set_num_threads`` and ``torch.set_num_threads`` always take precedence
over environment variables, and the ``MKL_NUM_THREADS`` variable takes precedence over ``OMP_NUM_THREADS``.

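The ordering above can be sketched as a small pure-Python model. This is a simplified illustration of the stated precedence only, not PyTorch's actual resolution logic. ``resolve_intra_op_threads`` is a hypothetical helper, and the fallback to the number of CPU cores is an assumption taken from the table's note:

```python
import os


def resolve_intra_op_threads(explicit=None, env=None):
    """Hypothetical model of the precedence rules stated above."""
    env = os.environ if env is None else env
    # An explicit set_num_threads() call wins over any environment variable.
    if explicit is not None:
        return explicit
    # MKL_NUM_THREADS is consulted before OMP_NUM_THREADS.
    for name in ("MKL_NUM_THREADS", "OMP_NUM_THREADS"):
        value = env.get(name)
        if value:
            return int(value)
    # Assumed fallback: the number of CPU cores.
    return os.cpu_count() or 1


env = {"OMP_NUM_THREADS": "8", "MKL_NUM_THREADS": "4"}
print(resolve_intra_op_threads(env=env))              # 4
print(resolve_intra_op_threads(explicit=2, env=env))  # 2
```
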

Tuning the number of threads
----------------------------

The following simple script shows how the runtime of matrix multiplication changes with the number of threads:

.. code-block:: python

    import timeit

    import torch  # needed here for torch.set_num_threads below

    runtimes = []
    threads = [1] + [t for t in range(2, 49, 2)]
    for t in threads:
        torch.set_num_threads(t)
        r = timeit.timeit(setup = "import torch; x = torch.randn(1024, 1024); y = torch.randn(1024, 1024)", stmt="torch.mm(x, y)", number=100)
        runtimes.append(r)
    # ... plotting (threads, runtimes) ...

Running the script on a system with 24 physical CPU cores (Xeon E5-2680, MKL and OpenMP based build) results in the following runtimes:

.. image:: cpu_threading_runtimes.svg
   :width: 75%

The following considerations should be taken into account when tuning the number of intra- and inter-op threads:

* When choosing the number of threads, one needs to avoid `oversubscription` (using too many threads, which leads to performance degradation). For example, in an application that uses a large application thread pool or heavily relies on
  inter-op parallelism, one might consider disabling intra-op parallelism (i.e. by calling ``set_num_threads(1)``);

* In a typical application one might encounter a trade-off between `latency` (time spent on processing an inference request) and `throughput` (amount of work done per unit of time). Tuning the number of threads can be a useful
  tool to adjust this trade-off in one way or another. For example, in latency-critical applications one might want to increase the number of intra-op threads to process each request as fast as possible. At the same time, parallel implementations
  of ops may add extra overhead that increases the amount of work done per request and thus reduces the overall throughput.

.. warning::
    OpenMP does not guarantee that a single per-process intra-op thread
    pool is going to be used in the application. On the contrary, two different application or inter-op
    threads may use different OpenMP thread pools for intra-op work.
    This might result in a large number of threads used by the application.
    Extra care in tuning the number of threads is needed to avoid
    oversubscription in multi-threaded applications in the OpenMP case.

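To make the oversubscription risk concrete, here is a back-of-the-envelope sketch. It assumes the worst case described in the warning, where every application or inter-op thread ends up with its own intra-op pool. ``worst_case_threads`` and ``is_oversubscribed`` are hypothetical helpers, not part of PyTorch:

```python
import os


def worst_case_threads(caller_threads, intra_op_threads):
    """Upper bound when every caller thread gets its own intra-op pool."""
    return caller_threads * intra_op_threads


def is_oversubscribed(caller_threads, intra_op_threads, cores=None):
    """True if the worst-case thread count exceeds the available cores."""
    if cores is None:
        cores = os.cpu_count() or 1
    return worst_case_threads(caller_threads, intra_op_threads) > cores


# 8 application threads, each running ops with 24 intra-op threads:
print(worst_case_threads(8, 24))            # 192
print(is_oversubscribed(8, 24, cores=24))   # True
# Disabling intra-op parallelism keeps the total at 8 threads:
print(is_oversubscribed(8, 1, cores=24))    # False
```

This is only an upper bound. How many threads are active at once depends on the workload and on the OpenMP runtime in use.
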
.. note::
    Pre-built PyTorch releases are compiled with OpenMP support.

.. note::
    The ``parallel_info`` utility prints information about thread settings and can be used for debugging.
    Similar output can also be obtained in Python with a ``torch.__config__.parallel_info()`` call.

.. _OpenMP: https://www.openmp.org/
.. _TBB: https://github.com/intel/tbb
.. _MKL: https://software.intel.com/en-us/mkl
.. _MKL-DNN: https://github.com/intel/mkl-dnn